diff --git a/go.mod b/go.mod index d6f6cbe..a02093e 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/google/go-containerregistry v0.20.6 github.com/gpustack/gguf-parser-go v0.22.1 github.com/sirupsen/logrus v1.9.3 + github.com/spf13/cobra v1.9.1 ) require ( @@ -16,6 +17,7 @@ require ( github.com/docker/distribution v2.8.3+incompatible // indirect github.com/docker/docker-credential-helpers v0.9.3 // indirect github.com/henvic/httpretty v0.1.4 // indirect + github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/compress v1.18.0 // indirect github.com/mitchellh/go-homedir v1.1.0 // indirect @@ -25,6 +27,7 @@ require ( github.com/opencontainers/image-spec v1.1.1 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/smallnest/ringbuffer v0.0.0-20241116012123-461381446e3d // indirect + github.com/spf13/pflag v1.0.6 // indirect github.com/stretchr/testify v1.10.0 // indirect github.com/vbatts/tar-split v0.12.1 // indirect golang.org/x/crypto v0.35.0 // indirect diff --git a/go.sum b/go.sum index d900d86..88cf1b5 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,6 @@ github.com/containerd/stargz-snapshotter/estargz v0.16.3 h1:7evrXtoh1mSbGj/pfRccTampEyKpjpOnS3CyiV1Ebr8= github.com/containerd/stargz-snapshotter/estargz v0.16.3/go.mod h1:uyr4BfYfOj3G9WBVE8cOlQmXAbPN9VEQpBBeJIuOipU= +github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -18,6 +19,8 @@ github.com/gpustack/gguf-parser-go v0.22.1 h1:FRnEDWqT0Rcplr/R9ctCRSN2+3DhVsf6dn github.com/gpustack/gguf-parser-go v0.22.1/go.mod h1:y4TwTtDqFWTK+xvprOjRUh+dowgU2TKCX37vRKvGiZ0= github.com/henvic/httpretty v0.1.4 h1:Jo7uwIRWVFxkqOnErcoYfH90o3ddQyVrSANeS4cxYmU= github.com/henvic/httpretty v0.1.4/go.mod h1:Dn60sQTZfbt2dYsdUSNsCljyF4AfdqnuJFDLJA1I4AM= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= @@ -37,10 +40,15 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/smallnest/ringbuffer v0.0.0-20241116012123-461381446e3d h1:3VwvTjiRPA7cqtgOWddEL+JrcijMlXUmj99c/6YyZoY= github.com/smallnest/ringbuffer v0.0.0-20241116012123-461381446e3d/go.mod h1:tAG61zBM1DYRaGIPloumExGvScf08oHuo0kFoOqdbT0= +github.com/spf13/cobra v1.9.1 h1:CXSaggrXdbHK9CF+8ywj8Amf7PBRmPCOJugH954Nnlo= +github.com/spf13/cobra v1.9.1/go.mod 
h1:nDyEzZ8ogv936Cinf6g1RU9MRY64Ir93oCnqb9wxYW0= +github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o= +github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= diff --git a/tools/benchmarks/parallelget/README.md b/tools/benchmarks/parallelget/README.md new file mode 100644 index 0000000..2db5a75 --- /dev/null +++ b/tools/benchmarks/parallelget/README.md @@ -0,0 +1,110 @@ +# ParallelGet Benchmark Tool + +A command-line benchmarking tool that compares the performance of standard HTTP GET requests against parallelized requests using the `transport/parallel` package. + +## Features + +- **Performance Comparison**: Downloads the same URL twice (standard vs parallel) and compares timing +- **Response Validation**: Ensures both downloads produce identical results (byte-for-byte comparison) +- **Configurable Parameters**: Adjustable chunk size and concurrency settings +- **Detailed Metrics**: Reports download speeds, timing differences, and performance improvements +- **Dynamic Progress Display**: Shows real-time progress bars during downloads with percentage and byte counts +- **Clean Output**: User-friendly performance summary with emojis and clear formatting + +## Usage + +```bash +go run ./tools/benchmarks/parallelget [flags] +``` + +or + +```bash +go build ./tools/benchmarks/parallelget +./parallelget [flags] +``` + +### Arguments + +- ``: The HTTP URL to benchmark (required) + +### Flags + +- `--chunk-size int`: Minimum chunk size in bytes for parallelization (default: 1MB) +- `--max-concurrent uint`: Maximum concurrent requests for parallel transport (default: 4) +- `-h, --help`: Show help information + +### Examples + +```bash +# Basic usage +./parallelget https://example.com/large-file.zip + +# Custom chunk size (512KB) and higher concurrency +./parallelget https://example.com/large-file.zip --chunk-size 524288 --max-concurrent 8 + +# Small chunk size for testing with smaller files +./parallelget https://httpbin.org/bytes/10485760 --chunk-size 262144 --max-concurrent 6 +``` + +## Output + +The tool provides detailed output including: + +1. **Configuration**: Shows the chunk size and concurrency settings +2. **Progress**: Real-time updates for each benchmark phase +3. **Individual Results**: Download speed and timing for each approach +4. **Validation**: Confirms that both downloads produced identical content +5. **Performance Summary**: + - Speedup factor (e.g., "3.2x faster") + - Time saved/penalty + - Detailed timing breakdown + +### Sample Output + +``` +Benchmarking HTTP GET performance for: https://example.com/large-file.zip +Configuration: chunk-size=1048576 bytes, max-concurrent=4 + +Running non-parallel benchmark... + Progress: [██████████████████████████████] 100.0% (10485760/10485760 bytes) +✓ Non-parallel: 10485760 bytes in 2.1s (4.76 MB/s) +Running parallel benchmark... + Progress: [██████████████████████████████] 100.0% (10485760/10485760 bytes) +✓ Parallel: 10485760 bytes in 650ms (15.38 MB/s) +Validating response consistency... 
+✓ Responses match perfectly + +============================================================ +PERFORMANCE COMPARISON +============================================================ +🚀 Parallel was 3.23x faster than non-parallel +⏱️ Time saved: 1.45s (69.0%) + +Detailed timing: + Non-parallel: 2.1s + Parallel: 650ms + Difference: -1.45s +``` + +## How It Works + +1. **Non-Parallel Benchmark**: Uses `net/http.DefaultClient` with `net/http.DefaultTransport` +2. **Parallel Benchmark**: Uses `net/http.DefaultClient` with `transport/parallel.ParallelTransport` wrapping `net/http.DefaultTransport` +3. **Response Storage**: Both responses are written to temporary files for validation +4. **Validation**: Performs byte-by-byte comparison to ensure identical content +5. **Cleanup**: Automatically removes temporary files after completion + +## Notes + +- The tool requires the server to support HTTP range requests (`Accept-Ranges: bytes`) for parallel downloads to work +- If the server doesn't support range requests or the file is too small, the parallel transport will automatically fall back to a single request +- Temporary files are automatically cleaned up, even if the tool exits unexpectedly +- The tool validates that both downloads produce identical results before reporting performance metrics + +## Use Cases + +- **Performance Testing**: Evaluate the effectiveness of parallel downloads for different URLs +- **Configuration Tuning**: Find optimal chunk size and concurrency settings for specific servers or file types +- **Server Compatibility**: Test whether servers properly support range requests +- **Network Optimization**: Understand the impact of parallel downloads on different network conditions diff --git a/tools/benchmarks/parallelget/main.go b/tools/benchmarks/parallelget/main.go new file mode 100644 index 0000000..627cd70 --- /dev/null +++ b/tools/benchmarks/parallelget/main.go @@ -0,0 +1,348 @@ +package main + +import ( + "bytes" + "crypto/sha256" + "fmt" + "io" + "net/http" + "os" + "strings" + "sync" + "time" + + "github.com/spf13/cobra" + + "github.com/docker/model-distribution/transport/parallel" +) + +var ( + minChunkSize int64 + maxConcurrent uint +) + +var rootCmd = &cobra.Command{ + Use: "parallelget ", + Short: "Benchmark parallel vs non-parallel HTTP GET requests", + Long: `parallelget is a benchmarking tool that compares the performance of standard +HTTP GET requests against parallelized requests using the transport/parallel package. + +It downloads the same URL twice - once using the standard HTTP client and once +using a parallel transport - then compares the results and reports performance metrics.`, + Args: cobra.ExactArgs(1), + RunE: runBenchmark, + SilenceUsage: true, +} + +func init() { + rootCmd.Flags().Int64Var(&minChunkSize, "chunk-size", 1024*1024, "Minimum chunk size in bytes for parallelization (default 1MB)") + rootCmd.Flags().UintVar(&maxConcurrent, "max-concurrent", 4, "Maximum concurrent requests for parallel transport (default 4)") +} + +func main() { + if err := rootCmd.Execute(); err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } +} + +func runBenchmark(cmd *cobra.Command, args []string) error { + url := args[0] + + fmt.Printf("Benchmarking HTTP GET performance for: %s\n", url) + fmt.Printf("Configuration: chunk-size=%d bytes, max-concurrent=%d\n\n", minChunkSize, maxConcurrent) + + // Create temporary files for storing responses. 
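+	// os.CreateTemp replaces the final "*" in each pattern with a random
+	// string, so every run gets unique, collision-free files in the system
+	// temp directory.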
+ nonParallelFile, err := os.CreateTemp("", "benchmark-non-parallel-*.tmp") + if err != nil { + return fmt.Errorf("failed to create temp file for non-parallel response: %w", err) + } + defer func() { + nonParallelFile.Close() + os.Remove(nonParallelFile.Name()) + }() + + parallelFile, err := os.CreateTemp("", "benchmark-parallel-*.tmp") + if err != nil { + return fmt.Errorf("failed to create temp file for parallel response: %w", err) + } + defer func() { + parallelFile.Close() + os.Remove(parallelFile.Name()) + }() + + // Run non-parallel benchmark. + fmt.Println("Running non-parallel benchmark...") + nonParallelDuration, nonParallelSize, err := benchmarkNonParallel(url, nonParallelFile) + if err != nil { + return fmt.Errorf("non-parallel benchmark failed: %w", err) + } + fmt.Printf("✓ Non-parallel: %d bytes in %v (%.2f MB/s)\n", nonParallelSize, nonParallelDuration, + float64(nonParallelSize)/nonParallelDuration.Seconds()/(1024*1024)) + + // Run parallel benchmark. + fmt.Println("Running parallel benchmark...") + parallelDuration, parallelSize, err := benchmarkParallel(url, parallelFile) + if err != nil { + return fmt.Errorf("parallel benchmark failed: %w", err) + } + fmt.Printf("✓ Parallel: %d bytes in %v (%.2f MB/s)\n", parallelSize, parallelDuration, + float64(parallelSize)/parallelDuration.Seconds()/(1024*1024)) + + // Validate responses match. + fmt.Println("Validating response consistency...") + if err := validateResponses(nonParallelFile, parallelFile); err != nil { + return fmt.Errorf("response validation failed: %w", err) + } + fmt.Println("✓ Responses match perfectly") + + // Print performance comparison. + fmt.Println("\n" + strings.Repeat("=", 60)) + fmt.Println("PERFORMANCE COMPARISON") + fmt.Println(strings.Repeat("=", 60)) + + speedup := float64(nonParallelDuration) / float64(parallelDuration) + if speedup > 1.0 { + fmt.Printf("🚀 Parallel was %.2fx faster than non-parallel\n", speedup) + timeSaved := nonParallelDuration - parallelDuration + fmt.Printf("⏱️ Time saved: %v (%.1f%%)\n", timeSaved, (1.0-1.0/speedup)*100) + } else if speedup < 1.0 { + slowdown := 1.0 / speedup + fmt.Printf("⚠️ Parallel was %.2fx slower than non-parallel\n", slowdown) + fmt.Printf("⏱️ Time penalty: %v (%.1f%%)\n", parallelDuration-nonParallelDuration, (slowdown-1.0)*100) + } else { + fmt.Println("📊 Both approaches performed equally") + } + + fmt.Printf("\nDetailed timing:\n") + fmt.Printf(" Non-parallel: %v\n", nonParallelDuration) + fmt.Printf(" Parallel: %v\n", parallelDuration) + fmt.Printf(" Difference: %v\n", parallelDuration-nonParallelDuration) + + return nil +} + +// performHTTPGet executes an HTTP GET request using the specified transport +// and measures the time taken to download the entire response body. +// The response is written to outputFile and progress is displayed during the download. +func performHTTPGet(url string, transport http.RoundTripper, outputFile *os.File) (time.Duration, int64, error) { + client := &http.Client{ + Transport: transport, + } + + start := time.Now() + + resp, err := client.Get(url) + if err != nil { + return 0, 0, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return 0, 0, fmt.Errorf("HTTP %d: %s", resp.StatusCode, resp.Status) + } + + // Create progress writer with content length if available. + contentLength := resp.ContentLength + if contentLength <= 0 { + contentLength = -1 // Unknown size. 
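+		// The progress writer falls back to a bytes-only display when the
+		// total is unknown; see progressWriter.printProgress below.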
+ } + progressWriter := newProgressWriter(outputFile, contentLength, " Progress") + + written, err := io.Copy(progressWriter, resp.Body) + progressWriter.finish() // Ensure final progress is shown. + + if err != nil { + return 0, 0, err + } + + duration := time.Since(start) + return duration, written, nil +} + +// benchmarkNonParallel performs a standard HTTP GET request using the default transport +// and measures the time taken to download the entire response body. +// The response is written to outputFile and progress is displayed during the download. +func benchmarkNonParallel(url string, outputFile *os.File) (time.Duration, int64, error) { + return performHTTPGet(url, http.DefaultTransport, outputFile) +} + +// benchmarkParallel performs an HTTP GET request using the parallel transport +// and measures the time taken to download the entire response body. +// The parallel transport uses byte-range requests to download chunks concurrently. +// The response is written to outputFile and progress is displayed during the download. +func benchmarkParallel(url string, outputFile *os.File) (time.Duration, int64, error) { + // Create parallel transport with configuration. + parallelTransport := parallel.New( + http.DefaultTransport, + parallel.WithMaxConcurrentPerHost(map[string]uint{"": 0}), + parallel.WithMinChunkSize(minChunkSize), + parallel.WithMaxConcurrentPerRequest(maxConcurrent), + ) + + return performHTTPGet(url, parallelTransport, outputFile) +} + +func validateResponses(file1, file2 *os.File) error { + // Get file sizes first for quick comparison. + stat1, err := file1.Stat() + if err != nil { + return fmt.Errorf("failed to stat non-parallel file: %w", err) + } + + stat2, err := file2.Stat() + if err != nil { + return fmt.Errorf("failed to stat parallel file: %w", err) + } + + // Compare file sizes - if they differ, no need to compute hashes. + if stat1.Size() != stat2.Size() { + return fmt.Errorf("file sizes differ: non-parallel=%d bytes, parallel=%d bytes", + stat1.Size(), stat2.Size()) + } + + // Compute SHA-256 hash for first file. + hash1, err := computeFileHash(file1) + if err != nil { + return fmt.Errorf("failed to compute hash for non-parallel file: %w", err) + } + + // Compute SHA-256 hash for second file. + hash2, err := computeFileHash(file2) + if err != nil { + return fmt.Errorf("failed to compute hash for parallel file: %w", err) + } + + // Compare the hashes. + if !bytes.Equal(hash1, hash2) { + return fmt.Errorf("file contents differ: SHA-256 hashes do not match") + } + + return nil +} + +// computeFileHash computes the SHA-256 hash of a file's contents. +// The file is read from the beginning using a single io.Copy operation for efficiency. +func computeFileHash(file *os.File) ([]byte, error) { + // Seek to beginning of file. + if _, err := file.Seek(0, io.SeekStart); err != nil { + return nil, fmt.Errorf("failed to seek to beginning: %w", err) + } + + // Create SHA-256 hasher. + hasher := sha256.New() + + // Copy entire file content to hasher in a single operation. + _, err := io.Copy(hasher, file) + if err != nil { + return nil, fmt.Errorf("failed to read file for hashing: %w", err) + } + + // Return the computed hash. + return hasher.Sum(nil), nil +} + +// progressWriter wraps an io.Writer and provides progress updates during writes. +// It displays a progress bar with percentage completion and transfer rates, +// updating the display at regular intervals to avoid excessive output. +type progressWriter struct { + // writer is the underlying writer to write data to. 
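+	// In this tool it is the temporary *os.File, but any io.Writer works.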
+ writer io.Writer + // total is the total expected bytes (-1 if unknown). + total int64 + // written is the number of bytes written so far. + written int64 + // lastUpdate is the last time the progress display was updated. + lastUpdate time.Time + // label is the label to display with the progress bar. + label string + // finished indicates whether the progress display has been finalized. + finished bool + // mu protects concurrent access to progress state. + mu sync.Mutex +} + +// newProgressWriter creates a new progress writer that wraps the given writer. +// The total parameter specifies the expected number of bytes (use -1 if unknown). +// The label parameter is displayed alongside the progress bar. +func newProgressWriter(writer io.Writer, total int64, label string) *progressWriter { + return &progressWriter{ + writer: writer, + total: total, + label: label, + lastUpdate: time.Now(), + } +} + +// Write implements io.Writer, writing data to the underlying writer and updating progress. +// Progress is displayed at most every 100ms to avoid overwhelming the terminal with updates. +// The final progress update is handled by the finish() method to ensure clean display. +func (pw *progressWriter) Write(data []byte) (int, error) { + // Write data to the underlying writer first. + n, err := pw.writer.Write(data) + if n > 0 { + pw.mu.Lock() + pw.written += int64(n) + now := time.Now() + + // Update progress every 100ms to balance responsiveness and performance. + // Don't update on completion - let finish() handle the final display. + if now.Sub(pw.lastUpdate) >= 100*time.Millisecond && (pw.total < 0 || pw.written < pw.total) { + pw.printProgress() + pw.lastUpdate = now + } + pw.mu.Unlock() + } + return n, err +} + +// printProgress displays the current progress to the terminal. +// For files with known size, shows a progress bar with percentage and bytes. +// For files with unknown size, shows only the bytes transferred. +// Uses carriage return (\r) to overwrite the previous progress line. +func (pw *progressWriter) printProgress() { + if pw.finished { + return + } + + if pw.total < 0 { + // Unknown total size - just show bytes transferred. + fmt.Printf("\r%s: %d bytes", pw.label, pw.written) + return + } + + // Calculate percentage, capping at 100% to handle edge cases. + percent := float64(pw.written) / float64(pw.total) * 100 + if percent > 100 { + percent = 100 + } + + // Create a visual progress bar using filled and empty characters. + barWidth := 30 + filled := int(percent / 100 * float64(barWidth)) + if filled > barWidth { + filled = barWidth + } + + bar := strings.Repeat("█", filled) + strings.Repeat("░", barWidth-filled) + + // Display progress bar with percentage and byte counts. + fmt.Printf("\r%s: [%s] %.1f%% (%d/%d bytes)", + pw.label, bar, percent, pw.written, pw.total) +} + +// finish completes the progress display by showing the final progress state +// and adding a newline to move the cursor to the next line. +// This ensures the progress bar doesn't interfere with subsequent output. +// It's safe to call multiple times - subsequent calls are ignored. +func (pw *progressWriter) finish() { + pw.mu.Lock() + defer pw.mu.Unlock() + if !pw.finished { + // Display final progress state. + pw.printProgress() + // Move to next line to prevent interference with subsequent output. 
+ fmt.Println() + pw.finished = true + } +} diff --git a/transport/internal/bufferfile/fifo.go b/transport/internal/bufferfile/fifo.go new file mode 100644 index 0000000..1f5dccb --- /dev/null +++ b/transport/internal/bufferfile/fifo.go @@ -0,0 +1,234 @@ +// Package bufferfile provides a FIFO implementation backed by a temporary file +// that supports concurrent reads and writes. +package bufferfile + +import ( + "fmt" + "io" + "os" + "sync" +) + +// FIFO is an io.ReadWriteCloser implementation that supports concurrent +// reads and writes to a temporary file. Reads begin from the start of the file +// and writes always append to the end. The type maintains separate read and write +// positions internally. +type FIFO struct { + // file is the underlying temporary file used for storage. + file *os.File + // mu protects all fields and synchronizes access to the FIFO. + mu sync.Mutex + // cond is used to signal waiting readers when new data becomes available + // or when the write side is closed. + cond *sync.Cond + // readPos tracks the current read position within the file. + readPos int64 + // writePos tracks the current write position within the file + // (always at EOF). + writePos int64 + // closed indicates whether Close() has been called, making the FIFO + // unusable. + closed bool + // writeClosed indicates whether CloseWrite() has been called, meaning + // no more writes will occur but reads can continue until all data is + // consumed. + writeClosed bool + // writeErr holds any persistent write error that should be returned to + // future write operations. + writeErr error +} + +// NewFIFO creates a new FIFO backed by a temporary file. +// The caller is responsible for calling Close() to clean up the temporary +// file. +func NewFIFO() (*FIFO, error) { + return NewFIFOInDir("") +} + +// NewFIFOInDir creates a new FIFO backed by a temporary file in the provided +// directory. If dir is empty, the system temporary directory is used. +// The caller is responsible for calling Close() to clean up the temporary +// file. +func NewFIFOInDir(dir string) (*FIFO, error) { + file, err := os.CreateTemp(dir, "model-buffer-*.tmp") + if err != nil { + return nil, fmt.Errorf("failed to create temporary file in dir: %w", err) + } + + fifo := &FIFO{ + file: file, + readPos: 0, + writePos: 0, + closed: false, + } + fifo.cond = sync.NewCond(&fifo.mu) + + return fifo, nil +} + +// Write implements io.Writer. Writes always append to the end of the file. +// Write is safe for concurrent use with Read. +func (f *FIFO) Write(p []byte) (int, error) { + f.mu.Lock() + defer f.mu.Unlock() + + // Check if FIFO is closed for writing. + if f.closed || f.writeClosed { + return 0, fmt.Errorf("write to closed FIFO") + } + + // Return persistent write error if we have one. + if f.writeErr != nil { + return 0, f.writeErr + } + + // Handle empty writes. + if len(p) == 0 { + return 0, nil + } + + // Seek to current write position (end of file). + _, err := f.file.Seek(f.writePos, io.SeekStart) + if err != nil { + f.writeErr = fmt.Errorf("seek to write position failed: %w", err) + return 0, f.writeErr + } + + // Write the data to the file. + n, err := f.file.Write(p) + if n > 0 { + // Update our write position to track how much data we've written. + f.writePos += int64(n) + // Signal all waiting readers that new data is available. + f.cond.Broadcast() + } + if err != nil { + // Store the error for future write attempts. 
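+		// Subsequent Write calls will fail fast by returning this error
+		// from the f.writeErr check at the top of Write.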
+ f.writeErr = fmt.Errorf("write failed: %w", err) + return n, f.writeErr + } + + return n, nil +} + +// Read implements io.Reader. Reads from the current read position in the file. +// Read blocks until data is available or the FIFO is closed. +// Read is safe for concurrent use with Write. +func (f *FIFO) Read(p []byte) (int, error) { + if len(p) == 0 { + return 0, nil + } + + f.mu.Lock() + defer f.mu.Unlock() + + for { + if f.closed { + // FIFO has been fully closed - file is closed and cleaned up. + // Return EOF immediately since no more data can be read. + return 0, io.EOF + } + + // Calculate how much unread data is available + availableBytes := f.writePos - f.readPos + if availableBytes > 0 { + // Data is available - read it immediately. + return f.readFromFile(p) + } + + // No data currently available - check if writes are finished + if f.writeClosed { + // Write side is closed and no data available - return EOF. + return 0, io.EOF + } + + // No data available and writes are still possible - wait for more + // data. + // The condition variable will be signaled when: + // - New data is written (f.cond.Broadcast() in Write). + // - Write side is closed (f.cond.Broadcast() in CloseWrite). + // - FIFO is fully closed (f.cond.Broadcast() in Close). + f.cond.Wait() + } +} + +// readFromFile performs the actual file read operation. +// Must be called with mutex held. +func (f *FIFO) readFromFile(p []byte) (int, error) { + availableBytes := f.writePos - f.readPos + toRead := int64(len(p)) + if toRead > availableBytes { + toRead = availableBytes + } + + // Seek to current read position + _, err := f.file.Seek(f.readPos, io.SeekStart) + if err != nil { + return 0, fmt.Errorf("seek to read position failed: %w", err) + } + + // Read the data + n, err := f.file.Read(p[:toRead]) + if n > 0 { + f.readPos += int64(n) + } + if err != nil && err != io.EOF { + return n, fmt.Errorf("read failed: %w", err) + } + + return n, nil +} + +// Close closes the FIFO and removes the temporary file. +// Any blocked Read or Write operations will be interrupted. +// Close is safe to call multiple times. +func (f *FIFO) Close() error { + f.mu.Lock() + defer f.mu.Unlock() + + if f.closed { + return nil + } + + f.closed = true + + // Wake up all waiting readers. + f.cond.Broadcast() + + var err error + if f.file != nil { + // Get the file name before closing for cleanup. + fileName := f.file.Name() + + // Close the file (this will interrupt any blocked I/O operations). + if closeErr := f.file.Close(); closeErr != nil { + err = fmt.Errorf("failed to close file: %w", closeErr) + } + + // Remove the temporary file. + if removeErr := os.Remove(fileName); removeErr != nil { + if err != nil { + err = fmt.Errorf("%w; also failed to remove temp file: %v", err, removeErr) + } else { + err = fmt.Errorf("failed to remove temp file: %w", removeErr) + } + } + + f.file = nil + } + + return err +} + +// CloseWrite signals that no more writes will happen. +// Readers can still read remaining data, and will receive EOF when all data +// is consumed. Does not clean up resources - use Close() for that. +func (f *FIFO) CloseWrite() { + f.mu.Lock() + defer f.mu.Unlock() + + f.writeClosed = true + + // Wake up all waiting readers to check the new state. 
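+	// Readers that find no remaining data will now observe writeClosed and
+	// return io.EOF instead of blocking again.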
+ f.cond.Broadcast() +} diff --git a/transport/internal/bufferfile/fifo_test.go b/transport/internal/bufferfile/fifo_test.go new file mode 100644 index 0000000..553503f --- /dev/null +++ b/transport/internal/bufferfile/fifo_test.go @@ -0,0 +1,621 @@ +package bufferfile + +import ( + "bytes" + "io" + "math/rand" + "sync" + "sync/atomic" + "testing" + "time" +) + +// stat returns information about the current state of the FIFO for testing +// purposes. +func (f *FIFO) stat() (readPos, writePos int64, closed bool) { + f.mu.Lock() + defer f.mu.Unlock() + return f.readPos, f.writePos, f.closed +} + +// TestFIFO_BasicReadWrite tests that data written to a FIFO can be read +// back exactly. This is the fundamental requirement for the FIFO to work +// correctly. +func TestFIFO_BasicReadWrite(t *testing.T) { + // Arrange: Create a new FIFO + fifo, err := NewFIFO() + if err != nil { + t.Fatalf("Failed to create FIFO: %v", err) + } + defer fifo.Close() + + data := []byte("hello world") + buf := make([]byte, len(data)) + + // Act: Write data to FIFO + n, err := fifo.Write(data) + if err != nil { + t.Fatalf("Write failed: %v", err) + } + if n != len(data) { + t.Fatalf("Expected to write %d bytes, wrote %d", len(data), n) + } + + // Act: Read data back from FIFO + n, err = fifo.Read(buf) + if err != nil { + t.Fatalf("Read failed: %v", err) + } + + // Assert: Verify read data matches written data + if n != len(data) { + t.Fatalf("Expected to read %d bytes, read %d", len(data), n) + } + if !bytes.Equal(buf, data) { + t.Fatalf("Read data doesn't match written data: got %q, want %q", buf, data) + } +} + +// TestFIFO_MultipleWrites tests that multiple separate writes are +// concatenated correctly when reading back from the FIFO, preserving the +// order and boundaries. +func TestFIFO_MultipleWrites(t *testing.T) { + // Arrange: Create FIFO and test data + fifo, err := NewFIFO() + if err != nil { + t.Fatalf("Failed to create FIFO: %v", err) + } + defer fifo.Close() + + chunks := [][]byte{ + []byte("chunk1"), + []byte("chunk2"), + []byte("chunk3"), + } + + // Act: Write multiple chunks sequentially + for i, chunk := range chunks { + n, err := fifo.Write(chunk) + if err != nil { + t.Fatalf("Write %d failed: %v", i, err) + } + if n != len(chunk) { + t.Fatalf("Write %d: expected %d bytes, wrote %d", i, len(chunk), n) + } + } + + // Act: Read all data back + expected := bytes.Join(chunks, nil) + buf := make([]byte, len(expected)) + totalRead := 0 + + for totalRead < len(expected) { + n, err := fifo.Read(buf[totalRead:]) + if err != nil { + t.Fatalf("Read failed: %v", err) + } + totalRead += n + } + + // Assert: Verify concatenated data is correct + if !bytes.Equal(buf, expected) { + t.Fatalf("Read data doesn't match expected: got %q, want %q", buf, expected) + } +} + +// TestFIFO_PartialReads tests that data can be read in smaller chunks than +// it was written, ensuring proper read position tracking. +func TestFIFO_PartialReads(t *testing.T) { + fifo, err := NewFIFO() + if err != nil { + t.Fatalf("Failed to create FIFO: %v", err) + } + defer fifo.Close() + + // Write data + data := []byte("0123456789") + _, err = fifo.Write(data) + if err != nil { + t.Fatalf("Write failed: %v", err) + } + + // Read in small chunks + buf := make([]byte, 3) // Smaller than data + var result []byte + + for len(result) < len(data) { + n, err := fifo.Read(buf) + if err != nil { + t.Fatalf("Read failed: %v", err) + } + result = append(result, buf[:n]...) 
+ } + + if !bytes.Equal(result, data) { + t.Fatalf("Partial read result doesn't match: got %q, want %q", result, data) + } +} + +// TestFIFO_ConcurrentReadWrite tests that multiple concurrent writers and +// readers can safely access the FIFO without data corruption or race +// conditions. +func TestFIFO_ConcurrentReadWrite(t *testing.T) { + fifo, err := NewFIFO() + if err != nil { + t.Fatalf("Failed to create FIFO: %v", err) + } + defer fifo.Close() + + const numWriters = 3 + const numChunksPerWriter = 100 + const chunkSize = 100 + + var wg sync.WaitGroup + var writeOrder []int + var writeOrderMu sync.Mutex + + // Start multiple writers + for writerID := 0; writerID < numWriters; writerID++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for i := 0; i < numChunksPerWriter; i++ { + // Create unique data for this writer and chunk + data := make([]byte, chunkSize) + for j := range data { + data[j] = byte((id*1000 + i) % 256) + } + + writeOrderMu.Lock() + writeOrder = append(writeOrder, id*1000+i) + writeOrderMu.Unlock() + + _, err := fifo.Write(data) + if err != nil { + t.Errorf("Writer %d chunk %d failed: %v", id, i, err) + return + } + } + }(writerID) + } + + // Read all data + var readData []byte + totalExpected := numWriters * numChunksPerWriter * chunkSize + buf := make([]byte, 1024) // Read buffer + + readDone := make(chan struct{}) + go func() { + defer close(readDone) + for len(readData) < totalExpected { + n, err := fifo.Read(buf) + if err != nil { + t.Errorf("Read failed: %v", err) + return + } + readData = append(readData, buf[:n]...) + } + }() + + // Wait for all writes to complete + wg.Wait() + + // Wait for all reads to complete + select { + case <-readDone: + // Success + case <-time.After(5 * time.Second): + t.Fatal("Read timed out") + } + + if len(readData) != totalExpected { + t.Fatalf("Expected to read %d bytes, got %d", totalExpected, len(readData)) + } + + t.Logf("Successfully handled %d concurrent writers writing %d total bytes", + numWriters, totalExpected) +} + +// TestFIFO_ReadBlocksUntilData tests that reads block when no data is +// available and unblock immediately when data is written, which is essential +// for the streaming behavior needed by the parallel transport. +func TestFIFO_ReadBlocksUntilData(t *testing.T) { + fifo, err := NewFIFO() + if err != nil { + t.Fatalf("Failed to create FIFO: %v", err) + } + defer fifo.Close() + + buf := make([]byte, 10) + readDone := make(chan struct{}) + var readErr error + + // Start a reader that should block + go func() { + defer close(readDone) + _, readErr = fifo.Read(buf) + }() + + // Ensure reader is blocked + select { + case <-readDone: + t.Fatal("Read should have blocked") + case <-time.After(100 * time.Millisecond): + // Good, read is blocked + } + + // Write data to unblock reader + data := []byte("test") + _, err = fifo.Write(data) + if err != nil { + t.Fatalf("Write failed: %v", err) + } + + // Now read should complete + select { + case <-readDone: + if readErr != nil { + t.Fatalf("Read failed: %v", readErr) + } + case <-time.After(time.Second): + t.Fatal("Read did not complete after write") + } +} + +// TestFIFO_CloseInterruptsRead tests that Close() interrupts blocked +// readers and causes them to return EOF, which is needed for proper cleanup. 
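+//
+// A hypothetical consumer pattern that relies on this guarantee:
+//
+//	fifo, _ := NewFIFO()
+//	go func() { _, _ = io.Copy(io.Discard, fifo) }() // may block in Read
+//	_ = fifo.Close()                                 // reader unblocks with io.EOF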
+func TestFIFO_CloseInterruptsRead(t *testing.T) { + fifo, err := NewFIFO() + if err != nil { + t.Fatalf("Failed to create FIFO: %v", err) + } + + buf := make([]byte, 10) + readDone := make(chan struct{}) + var readN int + var readErr error + + // Start a reader that should block + go func() { + defer close(readDone) + readN, readErr = fifo.Read(buf) + }() + + // Ensure reader is blocked + select { + case <-readDone: + t.Fatal("Read should have blocked") + case <-time.After(100 * time.Millisecond): + // Good, read is blocked + } + + // Close FIFO to interrupt read + err = fifo.Close() + if err != nil { + t.Fatalf("Close failed: %v", err) + } + + // Read should complete with EOF + select { + case <-readDone: + if readErr != io.EOF { + t.Fatalf("Expected EOF after close, got: %v", readErr) + } + if readN != 0 { + t.Fatalf("Expected 0 bytes read after close, got %d", readN) + } + case <-time.After(time.Second): + t.Fatal("Read did not complete after close") + } +} + +// TestFIFO_CloseWithPendingData tests that Close() immediately makes all +// data unavailable, which implements the interruptible FIFO semantics. +func TestFIFO_CloseWithPendingData(t *testing.T) { + fifo, err := NewFIFO() + if err != nil { + t.Fatalf("Failed to create FIFO: %v", err) + } + + // Write some data + data := []byte("pending data") + _, err = fifo.Write(data) + if err != nil { + t.Fatalf("Write failed: %v", err) + } + + // Close FIFO + err = fifo.Close() + if err != nil { + t.Fatalf("Close failed: %v", err) + } + + // After close, reads should return EOF immediately (data is lost) + buf := make([]byte, len(data)) + n, err := fifo.Read(buf) + if err != io.EOF { + t.Fatalf("Expected EOF after close, got: %v", err) + } + if n != 0 { + t.Fatalf("Expected 0 bytes read after close, got %d", n) + } +} + +// TestFIFO_WriteAfterClose tests that writes fail after the FIFO is closed. +func TestFIFO_WriteAfterClose(t *testing.T) { + fifo, err := NewFIFO() + if err != nil { + t.Fatalf("Failed to create FIFO: %v", err) + } + + err = fifo.Close() + if err != nil { + t.Fatalf("Close failed: %v", err) + } + + // Write after close should fail + _, err = fifo.Write([]byte("test")) + if err == nil { + t.Fatal("Expected write after close to fail") + } + + // Even empty writes should fail after close + _, err = fifo.Write(nil) + if err == nil { + t.Fatal("Expected empty write after close to fail") + } +} + +// TestFIFO_WriteAfterCloseWrite tests that writes fail after CloseWrite +// is called. +func TestFIFO_WriteAfterCloseWrite(t *testing.T) { + fifo, err := NewFIFO() + if err != nil { + t.Fatalf("Failed to create FIFO: %v", err) + } + defer fifo.Close() + + fifo.CloseWrite() + + // Write after CloseWrite should fail + _, err = fifo.Write([]byte("test")) + if err == nil { + t.Fatal("Expected write after CloseWrite to fail") + } + + // Even empty writes should fail after CloseWrite + _, err = fifo.Write(nil) + if err == nil { + t.Fatal("Expected empty write after CloseWrite to fail") + } +} + +// TestFIFO_Stat tests the internal stat method used for debugging and +// testing position tracking. 
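+// For example, after writing the 9-byte payload below and reading 4 bytes,
+// stat() should report readPos=4 and writePos=9.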
+func TestFIFO_Stat(t *testing.T) { + fifo, err := NewFIFO() + if err != nil { + t.Fatalf("Failed to create FIFO: %v", err) + } + defer fifo.Close() + + // Check initial state + readPos, writePos, closed := fifo.stat() + if readPos != 0 || writePos != 0 || closed { + t.Fatalf("Initial state wrong: readPos=%d, writePos=%d, closed=%v", + readPos, writePos, closed) + } + + // Write some data + data := []byte("test data") + _, err = fifo.Write(data) + if err != nil { + t.Fatalf("Write failed: %v", err) + } + + readPos, writePos, closed = fifo.stat() + if readPos != 0 || writePos != int64(len(data)) || closed { + t.Fatalf("After write state wrong: readPos=%d, writePos=%d, closed=%v", + readPos, writePos, closed) + } + + // Read some data + buf := make([]byte, 4) + n, err := fifo.Read(buf) + if err != nil { + t.Fatalf("Read failed: %v", err) + } + + readPos, writePos, closed = fifo.stat() + if readPos != int64(n) || writePos != int64(len(data)) || closed { + t.Fatalf("After read state wrong: readPos=%d, writePos=%d, closed=%v", + readPos, writePos, closed) + } + + // Close and check + fifo.Close() + readPos, writePos, closed = fifo.stat() + if !closed { + t.Fatal("FIFO should be marked as closed") + } +} + +// TestFIFO_StressTest performs concurrent read/write operations to test +// for race conditions and data corruption under heavy load. +func TestFIFO_StressTest(t *testing.T) { + if testing.Short() { + t.Skip("Skipping stress test in short mode") + } + + fifo, err := NewFIFO() + if err != nil { + t.Fatalf("Failed to create FIFO: %v", err) + } + defer fifo.Close() + + const duration = 2 * time.Second + const maxWriteSize = 1024 + const maxReadSize = 512 + + var totalWritten int64 + var totalRead int64 + var wg sync.WaitGroup + + // Start writer goroutine + wg.Add(1) + go func() { + defer wg.Done() + defer fifo.CloseWrite() + + // Signal to readers that no more bytes will arrive once the writer + // finishes so blocked reads can terminate. + start := time.Now() + for time.Since(start) < duration { + size := rand.Intn(maxWriteSize) + 1 + data := make([]byte, size) + rand.Read(data) + + n, err := fifo.Write(data) + if err != nil { + t.Errorf("Write failed: %v", err) + return + } + atomic.AddInt64(&totalWritten, int64(n)) + } + }() + + // Start reader goroutine + wg.Add(1) + go func() { + defer wg.Done() + buf := make([]byte, maxReadSize) + start := time.Now() + + for time.Since(start) < duration+time.Second { + // Give extra time to read. + n, err := fifo.Read(buf) + if err == io.EOF { + break + } + if err != nil { + t.Errorf("Read failed: %v", err) + return + } + atomic.AddInt64(&totalRead, int64(n)) + + // If we've read everything written and writer is done, we're + // done. + if atomic.LoadInt64(&totalRead) >= atomic.LoadInt64(&totalWritten) && + time.Since(start) > duration { + break + } + } + }() + + wg.Wait() + + finalWritten := atomic.LoadInt64(&totalWritten) + finalRead := atomic.LoadInt64(&totalRead) + t.Logf("Stress test completed: wrote %d bytes, read %d bytes", + finalWritten, finalRead) + + if finalRead > finalWritten { + t.Fatalf("Read more than written: read=%d, written=%d", + finalRead, finalWritten) + } +} + +// TestFIFO_EmptyOperations tests that empty reads and writes are handled +// correctly. 
+func TestFIFO_EmptyOperations(t *testing.T) { + fifo, err := NewFIFO() + if err != nil { + t.Fatalf("Failed to create FIFO: %v", err) + } + defer fifo.Close() + + // Test empty write + n, err := fifo.Write(nil) + if err != nil { + t.Fatalf("Empty write failed: %v", err) + } + if n != 0 { + t.Fatalf("Expected 0 bytes written for empty write, got %d", n) + } + + // Test empty read + n, err = fifo.Read(nil) + if err != nil { + t.Fatalf("Empty read failed: %v", err) + } + if n != 0 { + t.Fatalf("Expected 0 bytes read for empty read, got %d", n) + } +} + +// TestFIFO_MultipleClose tests that calling Close() multiple times is +// safe and doesn't cause errors or panics. +func TestFIFO_MultipleClose(t *testing.T) { + fifo, err := NewFIFO() + if err != nil { + t.Fatalf("Failed to create FIFO: %v", err) + } + + // First close should succeed + err = fifo.Close() + if err != nil { + t.Fatalf("First close failed: %v", err) + } + + // Second close should not panic and should not error + err = fifo.Close() + if err != nil { + t.Fatalf("Second close failed: %v", err) + } +} + +// Benchmark tests. +// BenchmarkFIFO_Write measures the performance of write operations. +func BenchmarkFIFO_Write(b *testing.B) { + fifo, err := NewFIFO() + if err != nil { + b.Fatalf("Failed to create FIFO: %v", err) + } + defer fifo.Close() + + data := make([]byte, 1024) + rand.Read(data) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := fifo.Write(data) + if err != nil { + b.Fatalf("Write failed: %v", err) + } + } +} + +// BenchmarkFIFO_Read measures the performance of read operations. +func BenchmarkFIFO_Read(b *testing.B) { + fifo, err := NewFIFO() + if err != nil { + b.Fatalf("Failed to create FIFO: %v", err) + } + defer fifo.Close() + + // Pre-fill with data + data := make([]byte, 1024) + rand.Read(data) + for i := 0; i < b.N; i++ { + fifo.Write(data) + } + + buf := make([]byte, 1024) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := fifo.Read(buf) + if err != nil { + b.Fatalf("Read failed: %v", err) + } + } +} diff --git a/transport/internal/common/http_utils.go b/transport/internal/common/http_utils.go new file mode 100644 index 0000000..a71186d --- /dev/null +++ b/transport/internal/common/http_utils.go @@ -0,0 +1,110 @@ +// Package common provides shared utilities for HTTP transport implementations. +package common + +import ( + "net/http" + "strconv" + "strings" +) + +// SupportsRange determines whether an HTTP response indicates support for range requests. +func SupportsRange(h http.Header) bool { + ar := strings.ToLower(h.Get("Accept-Ranges")) + for _, part := range strings.Split(ar, ",") { + if strings.TrimSpace(part) == "bytes" { + return true + } + } + return false +} + +// ScrubConditionalHeaders removes conditional headers we do not want to forward +// on range requests, because they can alter semantics or conflict with If-Range logic. +func ScrubConditionalHeaders(h http.Header) { + h.Del("If-None-Match") + h.Del("If-Modified-Since") + h.Del("If-Match") + h.Del("If-Unmodified-Since") + // Range/If-Range headers are set explicitly by the caller. +} + +// IsWeakETag reports whether the ETag is a weak validator (W/"...") which must +// not be used with If-Range per RFC 7232 §2.1. +func IsWeakETag(etag string) bool { + etag = strings.TrimSpace(etag) + return strings.HasPrefix(etag, "W/") || strings.HasPrefix(etag, "w/") +} + +// ParseSingleRange parses a single "Range: bytes=start-end" header. +// It returns (start, end, ok). When end is omitted, end == -1. 
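+// For example, "bytes=0-99" yields (0, 99, true) and "bytes=100-" yields
+// (100, -1, true).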
+// +// Notes: +// - Only absolute-start forms are supported (no suffix ranges "-N"). +// - Multi-range specifications (comma separated) return ok == false. +func ParseSingleRange(h string) (int64, int64, bool) { + if h == "" { + return 0, -1, false + } + h = strings.TrimSpace(h) + if !strings.HasPrefix(strings.ToLower(h), "bytes=") { + return 0, -1, false + } + spec := strings.TrimSpace(h[len("bytes="):]) + if strings.Contains(spec, ",") { + return 0, -1, false + } + parts := strings.SplitN(spec, "-", 2) + if len(parts) != 2 { + return 0, -1, false + } + if parts[0] == "" { + // Suffix form is not supported here. + return 0, -1, false + } + start, err := strconv.ParseInt(strings.TrimSpace(parts[0]), 10, 64) + if err != nil || start < 0 { + return 0, -1, false + } + end := int64(-1) + if strings.TrimSpace(parts[1]) != "" { + e, err := strconv.ParseInt(strings.TrimSpace(parts[1]), 10, 64) + if err != nil || e < start { + return 0, -1, false + } + end = e + } + return start, end, true +} + +// ParseContentRange parses "Content-Range: bytes start-end/total". It +// returns (start, end, total, ok). When total is unknown, total == -1. +func ParseContentRange(h string) (int64, int64, int64, bool) { + if h == "" { + return 0, -1, -1, false + } + h = strings.ToLower(strings.TrimSpace(h)) + if !strings.HasPrefix(h, "bytes ") { + return 0, -1, -1, false + } + body := strings.TrimSpace(h[len("bytes "):]) + seTotal := strings.SplitN(body, "/", 2) + if len(seTotal) != 2 { + return 0, -1, -1, false + } + se := strings.SplitN(strings.TrimSpace(seTotal[0]), "-", 2) + if len(se) != 2 { + return 0, -1, -1, false + } + start, err1 := strconv.ParseInt(strings.TrimSpace(se[0]), 10, 64) + end, err2 := strconv.ParseInt(strings.TrimSpace(se[1]), 10, 64) + totalStr := strings.TrimSpace(seTotal[1]) + var total int64 = -1 + var err3 error + if totalStr != "*" { + total, err3 = strconv.ParseInt(totalStr, 10, 64) + } + if err1 != nil || err2 != nil || (err3 != nil && totalStr != "*") { + return 0, -1, -1, false + } + return start, end, total, true +} diff --git a/transport/internal/common/http_utils_test.go b/transport/internal/common/http_utils_test.go new file mode 100644 index 0000000..1becd3b --- /dev/null +++ b/transport/internal/common/http_utils_test.go @@ -0,0 +1,195 @@ +package common + +import ( + "net/http" + "testing" +) + +// TestParseSingleRange exercises valid and invalid single-range specs. +func TestParseSingleRange(t *testing.T) { + cases := []struct { + in string + start, end int64 + ok bool + }{ + {"", 0, -1, false}, + {"bytes=0-99", 0, 99, true}, + {"bytes=0-", 0, -1, true}, + {"bytes=5-5", 5, 5, true}, + {"BYTES=7-9", 7, 9, true}, + // End before start. + {"bytes=10-5", 0, -1, false}, + // Suffix not supported. + {"bytes=-100", 0, -1, false}, + {"items=0-10", 0, -1, false}, + // Multi-range unsupported. + {"bytes=0-1,3-5", 0, -1, false}, + } + for _, tc := range cases { + start, end, ok := ParseSingleRange(tc.in) + if start != tc.start || end != tc.end || ok != tc.ok { + t.Errorf("ParseSingleRange(%q) = (%d,%d,%v), want (%d,%d,%v)", tc.in, start, end, ok, tc.start, tc.end, tc.ok) + } + } +} + +// TestParseContentRange exercises valid and invalid Content-Range headers. 
+func TestParseContentRange(t *testing.T) { + cases := []struct { + in string + start, end int64 + total int64 + ok bool + }{ + {"", 0, -1, -1, false}, + {"bytes 0-99/200", 0, 99, 200, true}, + {"BYTES 1-1/2", 1, 1, 2, true}, + {"bytes 0-0/*", 0, 0, -1, true}, + {"items 0-1/2", 0, -1, -1, false}, + {"bytes 0-99/abc", 0, -1, -1, false}, + // Parser accepts; semantic check happens elsewhere. + {"bytes 5-4/10", 5, 4, 10, true}, + } + for _, tc := range cases { + start, end, total, ok := ParseContentRange(tc.in) + if start != tc.start || end != tc.end || total != tc.total || ok != tc.ok { + t.Errorf("ParseContentRange(%q) = (%d,%d,%d,%v), want (%d,%d,%d,%v)", tc.in, start, end, total, ok, tc.start, tc.end, tc.total, tc.ok) + } + } +} + +// TestSupportsRange tests the Accept-Ranges header parsing. +func TestSupportsRange(t *testing.T) { + cases := []struct { + name string + header http.Header + expected bool + }{ + { + name: "no header", + header: http.Header{}, + expected: false, + }, + { + name: "bytes supported", + header: http.Header{"Accept-Ranges": []string{"bytes"}}, + expected: true, + }, + { + name: "bytes with mixed case", + header: http.Header{"Accept-Ranges": []string{"BYTES"}}, + expected: true, + }, + { + name: "bytes with other values", + header: http.Header{"Accept-Ranges": []string{"none, bytes"}}, + expected: true, + }, + { + name: "none only", + header: http.Header{"Accept-Ranges": []string{"none"}}, + expected: false, + }, + { + name: "other unit", + header: http.Header{"Accept-Ranges": []string{"items"}}, + expected: false, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + result := SupportsRange(tc.header) + if result != tc.expected { + t.Errorf("SupportsRange() = %v, want %v", result, tc.expected) + } + }) + } +} + +// TestIsWeakETag tests weak ETag detection. +func TestIsWeakETag(t *testing.T) { + cases := []struct { + name string + etag string + expected bool + }{ + { + name: "strong etag", + etag: `"abc123"`, + expected: false, + }, + { + name: "weak etag uppercase W", + etag: `W/"abc123"`, + expected: true, + }, + { + name: "weak etag lowercase w", + etag: `w/"abc123"`, + expected: true, + }, + { + name: "empty", + etag: "", + expected: false, + }, + { + name: "with spaces", + etag: ` W/"abc123" `, + expected: true, + }, + { + name: "malformed but starts with W", + etag: "W/malformed", + expected: true, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + result := IsWeakETag(tc.etag) + if result != tc.expected { + t.Errorf("IsWeakETag(%q) = %v, want %v", tc.etag, result, tc.expected) + } + }) + } +} + +// TestScrubConditionalHeaders tests conditional header removal. +func TestScrubConditionalHeaders(t *testing.T) { + // Set up test headers with both conditional and non-conditional headers. + headers := http.Header{ + "If-None-Match": []string{`"etag1"`}, + "If-Modified-Since": []string{"Wed, 21 Oct 2015 07:28:00 GMT"}, + "If-Match": []string{`"etag2"`}, + "If-Unmodified-Since": []string{"Thu, 22 Oct 2015 07:28:00 GMT"}, + "Range": []string{"bytes=0-99"}, + "If-Range": []string{`"etag3"`}, + "Authorization": []string{"Bearer token"}, + } + + // Scrub the conditional headers. + ScrubConditionalHeaders(headers) + + // Verify conditional headers are removed. 
+ conditionalHeaders := []string{ + "If-None-Match", + "If-Modified-Since", + "If-Match", + "If-Unmodified-Since", + } + for _, header := range conditionalHeaders { + if headers.Get(header) != "" { + t.Errorf("conditional header %s was not scrubbed", header) + } + } + + // Verify other headers are preserved. + preservedHeaders := []string{"Range", "If-Range", "Authorization"} + for _, header := range preservedHeaders { + if headers.Get(header) == "" { + t.Errorf("header %s was incorrectly removed", header) + } + } +} diff --git a/transport/internal/testing/fake_transport.go b/transport/internal/testing/fake_transport.go new file mode 100644 index 0000000..49c69fa --- /dev/null +++ b/transport/internal/testing/fake_transport.go @@ -0,0 +1,340 @@ +// Package testing provides common test utilities for transport packages. +package testing + +import ( + "bytes" + "fmt" + "io" + "net/http" + "strconv" + "strings" + "sync" +) + +// FakeResource represents a resource that can be served by FakeTransport. +type FakeResource struct { + // Data provides random access to the resource content. + Data io.ReaderAt + // Length is the total number of bytes in the resource content. + Length int64 + // SupportsRange indicates if this resource supports byte ranges. + SupportsRange bool + // ETag is the ETag header value (optional). + ETag string + // LastModified is the Last-Modified header value (optional). + LastModified string + // ContentType is the Content-Type header value (optional). + ContentType string + // Headers are additional headers to include in responses. + Headers http.Header +} + +// FakeTransport is a test http.RoundTripper that serves fake resources. +type FakeTransport struct { + mu sync.Mutex + resources map[string]*FakeResource + requests []http.Request + // FailAfter causes the transport to fail after serving this many bytes + // on a request (for simulating connection failures). + failAfter map[string]int + // failCount tracks how many times we've failed for each URL. + failCount map[string]int + // RequestHook is called for each request if set. + RequestHook func(*http.Request) + // ResponseHook is called for each response if set. + ResponseHook func(*http.Response) +} + +// NewFakeTransport creates a new FakeTransport. +func NewFakeTransport() *FakeTransport { + return &FakeTransport{ + resources: make(map[string]*FakeResource), + failAfter: make(map[string]int), + failCount: make(map[string]int), + } +} + +// Add adds a resource to the fake transport. +func (ft *FakeTransport) Add(url string, resource *FakeResource) { + ft.mu.Lock() + defer ft.mu.Unlock() + ft.resources[url] = resource +} + +// AddSimple adds a simple resource with the provided reader and length. +func (ft *FakeTransport) AddSimple(url string, data io.ReaderAt, length int64, supportsRange bool) { + ft.Add(url, &FakeResource{ + Data: data, + Length: length, + SupportsRange: supportsRange, + }) +} + +// SetFailAfter configures the transport to fail after serving n bytes for +// the given URL. +func (ft *FakeTransport) SetFailAfter(url string, n int) { + ft.mu.Lock() + defer ft.mu.Unlock() + ft.failAfter[url] = n +} + +// GetRequests returns a copy of all requests made to this transport. +func (ft *FakeTransport) GetRequests() []http.Request { + ft.mu.Lock() + defer ft.mu.Unlock() + reqs := make([]http.Request, len(ft.requests)) + copy(reqs, ft.requests) + return reqs +} + +// GetRequestHeaders returns the headers from all requests for a given URL. 
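+//
+// Hypothetical usage in a test, e.g. asserting the Range header sent by a
+// parallel transport on its second request (URL and index are illustrative):
+//
+//	hdrs := ft.GetRequestHeaders("https://example.com/blob")
+//	if hdrs[1].Get("Range") != "bytes=0-1048575" { /* fail the test */ }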
+func (ft *FakeTransport) GetRequestHeaders(url string) []http.Header { + ft.mu.Lock() + defer ft.mu.Unlock() + + var headers []http.Header + for _, req := range ft.requests { + if req.URL.String() == url { + h := make(http.Header) + for k, v := range req.Header { + h[k] = append([]string(nil), v...) + } + headers = append(headers, h) + } + } + return headers +} + +// RoundTrip implements http.RoundTripper. +func (ft *FakeTransport) RoundTrip(req *http.Request) (*http.Response, error) { + ft.mu.Lock() + // Store request + reqCopy := *req + if req.Header != nil { + reqCopy.Header = req.Header.Clone() + } + ft.requests = append(ft.requests, reqCopy) + + // Get resource + resource, exists := ft.resources[req.URL.String()] + failAfter := ft.failAfter[req.URL.String()] + ft.mu.Unlock() + + if ft.RequestHook != nil { + ft.RequestHook(req) + } + + if !exists { + return &http.Response{ + StatusCode: http.StatusNotFound, + Status: "404 Not Found", + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(http.Header), + Body: io.NopCloser(bytes.NewReader(nil)), + Request: req, + }, nil + } + + // Handle HEAD request + if req.Method == http.MethodHead { + resp := ft.createResponse(req, resource, nil, http.StatusOK) + if ft.ResponseHook != nil { + ft.ResponseHook(resp) + } + return resp, nil + } + + // Handle Range request + if rangeHeader := req.Header.Get("Range"); rangeHeader != "" && resource.SupportsRange { + return ft.handleRangeRequest(req, resource, rangeHeader, failAfter) + } + + // Regular GET request + var body io.ReadCloser + if failAfter > 0 && ft.getFailCount(req.URL.String()) == 0 { + // First request - fail after specified bytes + body = NewFlakyReader(resource.Data, resource.Length, failAfter) + ft.incrementFailCount(req.URL.String()) + } else { + // Subsequent request or no failure configured + body = io.NopCloser(io.NewSectionReader(resource.Data, 0, resource.Length)) + } + + resp := ft.createResponse(req, resource, body, http.StatusOK) + if ft.ResponseHook != nil { + ft.ResponseHook(resp) + } + return resp, nil +} + +// handleRangeRequest serves a single byte range request for a resource. +// It validates the Range and If-Range headers and returns either 206 with the +// requested slice, or 200 with the full resource if validation fails. +// Multi-range specifications are not supported and result in 400. 
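+// For example, "Range: bytes=0-99" against a 200-byte resource yields a 206
+// response with "Content-Range: bytes 0-99/200" and ContentLength 100.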
+func (ft *FakeTransport) handleRangeRequest(req *http.Request, resource *FakeResource, rangeHeader string, failAfter int) (*http.Response, error) { + // Parse range header (simplified - only handles single ranges) + if !strings.HasPrefix(rangeHeader, "bytes=") { + return ft.createErrorResponse(req, http.StatusBadRequest), nil + } + + rangeSpec := strings.TrimPrefix(rangeHeader, "bytes=") + parts := strings.Split(rangeSpec, "-") + if len(parts) != 2 { + return ft.createErrorResponse(req, http.StatusBadRequest), nil + } + + var start, end int64 + var err error + + if parts[0] != "" { + start, err = strconv.ParseInt(parts[0], 10, 64) + if err != nil { + return ft.createErrorResponse(req, http.StatusBadRequest), nil + } + } + + if parts[1] != "" { + end, err = strconv.ParseInt(parts[1], 10, 64) + if err != nil { + return ft.createErrorResponse(req, http.StatusBadRequest), nil + } + } else { + end = resource.Length - 1 + } + + // Validate range + if start < 0 || end >= resource.Length || start > end { + resp := ft.createErrorResponse(req, http.StatusRequestedRangeNotSatisfiable) + resp.Header.Set("Content-Range", fmt.Sprintf("bytes */%d", resource.Length)) + if ft.ResponseHook != nil { + ft.ResponseHook(resp) + } + return resp, nil + } + + // Check If-Range + if ifRange := req.Header.Get("If-Range"); ifRange != "" { + // Check if If-Range matches either ETag or Last-Modified + matches := false + + // Only match strong ETags for If-Range + if resource.ETag != "" && !strings.HasPrefix(resource.ETag, "W/") { + if ifRange == resource.ETag { + matches = true + } + } + + // Also check Last-Modified + if !matches && resource.LastModified != "" { + if ifRange == resource.LastModified { + matches = true + } + } + + if !matches { + // Validator doesn't match - return full content + body := NewFlakyReader(resource.Data, resource.Length, failAfter) + resp := ft.createResponse(req, resource, body, http.StatusOK) + if ft.ResponseHook != nil { + ft.ResponseHook(resp) + } + return resp, nil + } + } + + // Serve range + body := io.NopCloser(io.NewSectionReader(resource.Data, start, end-start+1)) + + resp := ft.createResponse(req, resource, body, http.StatusPartialContent) + resp.Header.Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, end, resource.Length)) + resp.ContentLength = end - start + 1 + + if ft.ResponseHook != nil { + ft.ResponseHook(resp) + } + return resp, nil +} + +// createResponse builds a basic http.Response for the given resource and +// status code, copying standard headers and any optional metadata. 
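+// Content-Length is populated only for 200 responses; 206 range responses
+// get their ContentLength set by handleRangeRequest instead.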
+func (ft *FakeTransport) createResponse(req *http.Request, resource *FakeResource, body io.ReadCloser, statusCode int) *http.Response {
+	if body == nil {
+		body = io.NopCloser(bytes.NewReader(nil))
+	}
+
+	resp := &http.Response{
+		StatusCode: statusCode,
+		Status:     fmt.Sprintf("%d %s", statusCode, http.StatusText(statusCode)),
+		Proto:      "HTTP/1.1",
+		ProtoMajor: 1,
+		ProtoMinor: 1,
+		Header:     make(http.Header),
+		Body:       body,
+		Request:    req,
+	}
+
+	// Set standard headers
+	if resource.SupportsRange {
+		resp.Header.Set("Accept-Ranges", "bytes")
+	}
+
+	if resource.ETag != "" {
+		resp.Header.Set("ETag", resource.ETag)
+	}
+
+	if resource.LastModified != "" {
+		resp.Header.Set("Last-Modified", resource.LastModified)
+	}
+
+	if resource.ContentType != "" {
+		resp.Header.Set("Content-Type", resource.ContentType)
+	}
+
+	// Copy additional headers
+	if resource.Headers != nil {
+		for k, v := range resource.Headers {
+			resp.Header[k] = v
+		}
+	}
+
+	// Set Content-Length
+	if statusCode == http.StatusOK {
+		resp.ContentLength = resource.Length
+		resp.Header.Set("Content-Length", strconv.FormatInt(resource.Length, 10))
+	}
+
+	return resp
+}
+
+// createErrorResponse constructs a minimal error response with the provided
+// status code and an empty body.
+func (ft *FakeTransport) createErrorResponse(req *http.Request, statusCode int) *http.Response {
+	return &http.Response{
+		StatusCode: statusCode,
+		Status:     fmt.Sprintf("%d %s", statusCode, http.StatusText(statusCode)),
+		Proto:      "HTTP/1.1",
+		ProtoMajor: 1,
+		ProtoMinor: 1,
+		Header:     make(http.Header),
+		Body:       io.NopCloser(bytes.NewReader(nil)),
+		Request:    req,
+	}
+}
+
+// getFailCount returns how many failures have been injected for the URL so
+// far. It is safe for concurrent use.
+func (ft *FakeTransport) getFailCount(url string) int {
+	ft.mu.Lock()
+	defer ft.mu.Unlock()
+	return ft.failCount[url]
+}
+
+// incrementFailCount increments the injected failure counter for the URL.
+// It is safe for concurrent use.
+func (ft *FakeTransport) incrementFailCount(url string) {
+	ft.mu.Lock()
+	defer ft.mu.Unlock()
+	ft.failCount[url]++
+}
diff --git a/transport/internal/testing/flaky_reader.go b/transport/internal/testing/flaky_reader.go
new file mode 100644
index 0000000..707d1ae
--- /dev/null
+++ b/transport/internal/testing/flaky_reader.go
@@ -0,0 +1,244 @@
+package testing
+
+import (
+	"errors"
+	"io"
+	"sync"
+)
+
+// ErrFlakyFailure is returned when FlakyReader simulates a failure.
+var ErrFlakyFailure = errors.New("simulated read failure")
+
+// FlakyReader simulates a reader that fails after a certain number of
+// bytes.
+type FlakyReader struct {
+	// data holds the content to be read through random access reads.
+	data io.ReaderAt
+	// length is the total number of readable bytes.
+	length int64
+	// failAfter is the byte position after which reads should fail.
+	failAfter int64
+	// pos is the current read position.
+	pos int64
+	// failed indicates if the reader has already failed.
+	failed bool
+	// closed indicates if the reader has been closed.
+	closed bool
+	// mu protects all fields from concurrent access.
+	mu sync.Mutex
+}
+
+// NewFlakyReader creates a FlakyReader that fails after reading failAfter
+// bytes. If failAfter is 0 or negative, it never fails.
+func NewFlakyReader(data io.ReaderAt, length int64, failAfter int) *FlakyReader {
+	return &FlakyReader{
+		data:      data,
+		length:    length,
+		failAfter: int64(failAfter),
+	}
+}
+
+// Read implements io.Reader.
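+// For example, a reader over 13 bytes created with failAfter=5 returns at
+// most 5 bytes before failing with ErrFlakyFailure (see
+// TestFlakyReader_FailsAfterN).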
+func (fr *FlakyReader) Read(p []byte) (int, error) { + fr.mu.Lock() + defer fr.mu.Unlock() + + if fr.closed { + return 0, errors.New("read from closed reader") + } + + if fr.failed { + return 0, ErrFlakyFailure + } + + if fr.pos >= fr.length { + return 0, io.EOF + } + + // Calculate how much we can read. + remaining := fr.length - fr.pos + toRead := int64(len(p)) + if toRead > remaining { + toRead = remaining + } + + // Check if we should fail. + if fr.failAfter > 0 && fr.pos+toRead > fr.failAfter { + toRead = fr.failAfter - fr.pos + if toRead <= 0 { + fr.failed = true + return 0, ErrFlakyFailure + } + } + + if toRead == 0 { + return 0, nil + } + + buf := p[:toRead] + n, err := fr.data.ReadAt(buf, fr.pos) + fr.pos += int64(n) + + if err != nil && err != io.EOF { + return n, err + } + + if fr.failAfter > 0 && fr.pos >= fr.failAfter && fr.pos < fr.length { + fr.failed = true + if n == 0 { + return 0, ErrFlakyFailure + } + } + + if fr.pos >= fr.length { + return n, io.EOF + } + + if err == io.EOF { + return n, io.EOF + } + + return n, nil +} + +// Close implements io.Closer. +func (fr *FlakyReader) Close() error { + fr.mu.Lock() + defer fr.mu.Unlock() + fr.closed = true + return nil +} + +// Reset resets the reader to start from the beginning. +func (fr *FlakyReader) Reset() { + fr.mu.Lock() + defer fr.mu.Unlock() + fr.pos = 0 + fr.failed = false + fr.closed = false +} + +// Position returns the current read position. +func (fr *FlakyReader) Position() int { + fr.mu.Lock() + defer fr.mu.Unlock() + return int(fr.pos) +} + +// HasFailed returns true if the reader has simulated a failure. +func (fr *FlakyReader) HasFailed() bool { + fr.mu.Lock() + defer fr.mu.Unlock() + return fr.failed +} + +// MultiFailReader simulates multiple failures at different points. +type MultiFailReader struct { + // data holds the content to be read through random access reads. + data io.ReaderAt + // length is the total number of readable bytes. + length int64 + // failurePoints are the byte positions where failures should occur. + failurePoints []int + // failureCount tracks how many failures have been simulated. + failureCount int + // pos is the current read position. + pos int64 + // closed indicates if the reader has been closed. + closed bool + // mu protects all fields from concurrent access. + mu sync.Mutex +} + +// NewMultiFailReader creates a reader that fails at specified byte +// positions. +func NewMultiFailReader(data io.ReaderAt, length int64, failurePoints []int) *MultiFailReader { + return &MultiFailReader{ + data: data, + length: length, + failurePoints: failurePoints, + } +} + +// Read implements io.Reader. +func (mfr *MultiFailReader) Read(p []byte) (int, error) { + mfr.mu.Lock() + defer mfr.mu.Unlock() + + if mfr.closed { + return 0, errors.New("read from closed reader") + } + + if mfr.pos >= mfr.length { + return 0, io.EOF + } + + // Check if we're at a failure point. + for i, point := range mfr.failurePoints { + if i < mfr.failureCount { + continue // Already failed here. + } + if mfr.pos == int64(point) { + mfr.failureCount++ + return 0, ErrFlakyFailure + } + } + + // Calculate how much to read. + remaining := mfr.length - mfr.pos + toRead := int64(len(p)) + if toRead > remaining { + toRead = remaining + } + + // Check if we would cross a failure point. + for i, point := range mfr.failurePoints { + if i < mfr.failureCount { + continue // Skip already used failure points. 
+ } + if mfr.pos < int64(point) && mfr.pos+toRead > int64(point) { + toRead = int64(point) - mfr.pos + break + } + } + + // Copy data. + if toRead == 0 { + return 0, nil + } + + buf := p[:toRead] + n, err := mfr.data.ReadAt(buf, mfr.pos) + mfr.pos += int64(n) + + if err != nil && err != io.EOF { + return n, err + } + + if mfr.pos >= mfr.length { + return n, io.EOF + } + + if err == io.EOF { + return n, io.EOF + } + + return n, nil +} + +// Close implements io.Closer. +func (mfr *MultiFailReader) Close() error { + mfr.mu.Lock() + defer mfr.mu.Unlock() + mfr.closed = true + return nil +} + +// Reset resets the reader to the beginning and clears failure state. +func (mfr *MultiFailReader) Reset() { + mfr.mu.Lock() + defer mfr.mu.Unlock() + mfr.pos = 0 + mfr.failureCount = 0 + mfr.closed = false +} diff --git a/transport/internal/testing/helpers.go b/transport/internal/testing/helpers.go new file mode 100644 index 0000000..e3aa2dd --- /dev/null +++ b/transport/internal/testing/helpers.go @@ -0,0 +1,198 @@ +package testing + +import ( + "bytes" + "crypto/rand" + "fmt" + "io" + "testing" +) + +// GenerateTestData generates deterministic test data of the specified size. +func GenerateTestData(size int) []byte { + data := make([]byte, size) + for i := range data { + data[i] = byte(i % 256) + } + return data +} + +// GenerateRandomData generates random test data of the specified size. +func GenerateRandomData(size int) []byte { + data := make([]byte, size) + if _, err := rand.Read(data); err != nil { + panic(fmt.Sprintf("failed to generate random data: %v", err)) + } + return data +} + +// AssertDataEquals checks if two byte slices are equal. +func AssertDataEquals(t *testing.T, got, want []byte) { + t.Helper() + if !bytes.Equal(got, want) { + t.Errorf("data mismatch: got %d bytes, want %d bytes", len(got), len(want)) + if len(got) == len(want) { + // Find first difference. + for i := range got { + if got[i] != want[i] { + t.Errorf( + "first difference at byte %d: got %02x, want %02x", + i, got[i], want[i]) + break + } + } + } + } +} + +// ReadAll reads all data from a reader and returns it. +func ReadAll(t *testing.T, r io.Reader) []byte { + t.Helper() + data, err := io.ReadAll(r) + if err != nil { + t.Fatalf("failed to read all data: %v", err) + } + return data +} + +// ReadAllWithError reads all data from a reader and returns both data and +// error. +func ReadAllWithError(r io.Reader) ([]byte, error) { + return io.ReadAll(r) +} + +// MustRead reads exactly n bytes from a reader or fails the test. +func MustRead(t *testing.T, r io.Reader, n int) []byte { + t.Helper() + buf := make([]byte, n) + nn, err := io.ReadFull(r, buf) + if err != nil { + t.Fatalf( + "failed to read %d bytes: got %d, err: %v", n, nn, err) + } + return buf +} + +// AssertHeaderEquals checks if a header has the expected value. +func AssertHeaderEquals(t *testing.T, headers map[string][]string, key, want string) { + t.Helper() + values, ok := headers[key] + if !ok || len(values) == 0 { + if want != "" { + t.Errorf("header %q not found, want %q", key, want) + } + return + } + if values[0] != want { + t.Errorf("header %q = %q, want %q", key, values[0], want) + } +} + +// AssertHeaderPresent checks if a header is present. +func AssertHeaderPresent(t *testing.T, headers map[string][]string, key string) { + t.Helper() + if _, ok := headers[key]; !ok { + t.Errorf("header %q not found", key) + } +} + +// AssertHeaderAbsent checks if a header is absent. 
+func AssertHeaderAbsent(t *testing.T, headers map[string][]string, key string) { + t.Helper() + if _, ok := headers[key]; ok { + t.Errorf("header %q found, want absent", key) + } +} + +// ChunkData splits data into n chunks of approximately equal size. +func ChunkData(data []byte, n int) [][]byte { + if n <= 0 { + return nil + } + if n == 1 { + return [][]byte{data} + } + + chunkSize := len(data) / n + remainder := len(data) % n + + chunks := make([][]byte, n) + offset := 0 + + for i := 0; i < n; i++ { + size := chunkSize + if i == n-1 { + size += remainder + } + chunks[i] = data[offset : offset+size] + offset += size + } + + return chunks +} + +// ConcatChunks concatenates multiple byte slices into one. +func ConcatChunks(chunks [][]byte) []byte { + var total int + for _, chunk := range chunks { + total += len(chunk) + } + + result := make([]byte, 0, total) + for _, chunk := range chunks { + result = append(result, chunk...) + } + + return result +} + +// ByteRange represents a byte range. +type ByteRange struct { + // Start is the starting byte position (inclusive). + Start int64 + // End is the ending byte position (inclusive). + End int64 +} + +// CalculateByteRanges calculates byte ranges for splitting a file of given +// size into n parts. +func CalculateByteRanges(totalSize int64, n int) []ByteRange { + if n <= 0 || totalSize <= 0 { + return nil + } + + ranges := make([]ByteRange, n) + chunkSize := totalSize / int64(n) + remainder := totalSize % int64(n) + + var start int64 + for i := 0; i < n; i++ { + size := chunkSize + if i == n-1 { + size += remainder + } + ranges[i] = ByteRange{ + Start: start, + End: start + size - 1, + } + start += size + } + + return ranges +} + +// AssertNoError fails the test if err is not nil. +func AssertNoError(t *testing.T, err error, msg string) { + t.Helper() + if err != nil { + t.Fatalf("%s: %v", msg, err) + } +} + +// AssertError fails the test if err is nil. +func AssertError(t *testing.T, err error, msg string) { + t.Helper() + if err == nil { + t.Fatalf("%s: expected error, got nil", msg) + } +} diff --git a/transport/internal/testing/testing_test.go b/transport/internal/testing/testing_test.go new file mode 100644 index 0000000..c68db49 --- /dev/null +++ b/transport/internal/testing/testing_test.go @@ -0,0 +1,137 @@ +package testing + +import ( + "bytes" + "io" + "net/http" + "testing" +) + +// TestFakeTransport_Basic tests the basic functionality of FakeTransport. +func TestFakeTransport_Basic(t *testing.T) { + ft := NewFakeTransport() + + // Add a simple resource. + data := []byte("Hello, World!") + ft.AddSimple("http://example.com/test", bytes.NewReader(data), int64(len(data)), true) + + // Create a request. + req, err := http.NewRequest("GET", "http://example.com/test", nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + + // Perform the request. + resp, err := ft.RoundTrip(req) + if err != nil { + t.Fatalf("RoundTrip failed: %v", err) + } + defer resp.Body.Close() + + // Read the response. + got, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Failed to read response: %v", err) + } + + // Check the data. + if !bytes.Equal(got, data) { + t.Errorf("Response data mismatch: got %q, want %q", got, data) + } +} + +// TestFlakyReader_FailsAfterN tests that FlakyReader fails after reading +// a specified number of bytes. +func TestFlakyReader_FailsAfterN(t *testing.T) { + data := []byte("Hello, World!") + fr := NewFlakyReader(bytes.NewReader(data), int64(len(data)), 5) + + // Read first 5 bytes. 
+ buf := make([]byte, 5) + n, err := fr.Read(buf) + if err != nil { + t.Fatalf("First read failed: %v", err) + } + if n != 5 { + t.Fatalf("Expected to read 5 bytes, got %d", n) + } + if string(buf) != "Hello" { + t.Errorf("Expected 'Hello', got %q", string(buf)) + } + + // Next read should fail. + _, err = fr.Read(buf) + if err != ErrFlakyFailure { + t.Errorf("Expected ErrFlakyFailure, got %v", err) + } +} + +// TestHelpers_GenerateTestData tests the deterministic test data generator. +func TestHelpers_GenerateTestData(t *testing.T) { + data := GenerateTestData(256) + + if len(data) != 256 { + t.Errorf("Expected 256 bytes, got %d", len(data)) + } + + // Check deterministic pattern. + for i := 0; i < 256; i++ { + if data[i] != byte(i%256) { + t.Errorf("Byte %d: expected %d, got %d", i, i%256, data[i]) + } + } +} + +// TestHelpers_ChunkData tests the data chunking functionality. +func TestHelpers_ChunkData(t *testing.T) { + data := GenerateTestData(100) + chunks := ChunkData(data, 4) + + if len(chunks) != 4 { + t.Fatalf("Expected 4 chunks, got %d", len(chunks)) + } + + // First 3 chunks should be 25 bytes each. + for i := 0; i < 3; i++ { + if len(chunks[i]) != 25 { + t.Errorf("Chunk %d: expected 25 bytes, got %d", i, len(chunks[i])) + } + } + + // Last chunk should be 25 + remainder. + if len(chunks[3]) != 25 { + t.Errorf("Last chunk: expected 25 bytes, got %d", len(chunks[3])) + } + + // Concatenate and verify. + combined := ConcatChunks(chunks) + if !bytes.Equal(combined, data) { + t.Error("Concatenated chunks don't match original data") + } +} + +// TestHelpers_ByteRanges tests byte range calculation for parallel +// downloads. +func TestHelpers_ByteRanges(t *testing.T) { + ranges := CalculateByteRanges(100, 4) + + if len(ranges) != 4 { + t.Fatalf("Expected 4 ranges, got %d", len(ranges)) + } + + expectedRanges := []ByteRange{ + {Start: 0, End: 24}, + {Start: 25, End: 49}, + {Start: 50, End: 74}, + {Start: 75, End: 99}, + } + + for i, r := range ranges { + if r.Start != expectedRanges[i].Start || + r.End != expectedRanges[i].End { + t.Errorf( + "Range %d: got %d-%d, want %d-%d", + i, r.Start, r.End, expectedRanges[i].Start, expectedRanges[i].End) + } + } +} diff --git a/transport/parallel/large_file_test.go b/transport/parallel/large_file_test.go new file mode 100644 index 0000000..ff9e84b --- /dev/null +++ b/transport/parallel/large_file_test.go @@ -0,0 +1,340 @@ +package parallel + +import ( + "bytes" + "crypto/sha256" + "fmt" + "hash" + "io" + "net/http" + "os" + "strconv" + "testing" + + testutil "github.com/docker/model-distribution/transport/internal/testing" +) + +// deterministicDataGenerator generates deterministic data based on position. +// This allows us to generate GB-sized data streams without storing them in +// memory. +type deterministicDataGenerator struct { + position int64 + size int64 +} + +// newDeterministicDataGenerator creates a new deterministic data generator +// with the specified size. +func newDeterministicDataGenerator(size int64) *deterministicDataGenerator { + return &deterministicDataGenerator{ + position: 0, + size: size, + } +} + +// Read implements io.Reader for deterministicDataGenerator. +func (g *deterministicDataGenerator) Read(p []byte) (int, error) { + if g.position >= g.size { + return 0, io.EOF + } + + // Calculate how much we can read. + remaining := g.size - g.position + toRead := int64(len(p)) + if toRead > remaining { + toRead = remaining + } + + // Generate deterministic data based on position. 
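+	// Each byte depends only on its absolute offset, so Read and ReadAt
+	// always agree on the content they produce.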
+	for i := int64(0); i < toRead; i++ {
+		pos := g.position + i
+		// Derive each byte from its absolute offset, XORing the offset with
+		// shifted copies of itself so the stream stays deterministic without
+		// repeating every 256 bytes.
+		p[i] = byte((pos ^ (pos >> 8) ^ (pos >> 16)) % 256)
+	}
+
+	g.position += toRead
+	return int(toRead), nil
+}
+
+// ReadAt implements io.ReaderAt for deterministicDataGenerator.
+func (g *deterministicDataGenerator) ReadAt(p []byte, off int64) (int, error) {
+	if off >= g.size {
+		return 0, io.EOF
+	}
+
+	remaining := g.size - off
+	toRead := int64(len(p))
+	if toRead > remaining {
+		toRead = remaining
+	}
+
+	for i := int64(0); i < toRead; i++ {
+		pos := off + i
+		p[i] = byte((pos ^ (pos >> 8) ^ (pos >> 16)) % 256)
+	}
+
+	if toRead < int64(len(p)) {
+		return int(toRead), io.EOF
+	}
+
+	return int(toRead), nil
+}
+
+// addLargeFileResource registers a deterministic large file with the fake
+// transport. The resource shares behavior with the previous httptest server
+// implementation, including range support and metadata headers.
+func addLargeFileResource(ft *testutil.FakeTransport, url string, size int64) {
+	ft.Add(url, &testutil.FakeResource{
+		Data:          newDeterministicDataGenerator(size),
+		Length:        size,
+		SupportsRange: true,
+		ETag:          fmt.Sprintf(`"test-file-%d"`, size),
+		ContentType:   "application/octet-stream",
+	})
+}
+
+// hashingReader wraps an io.Reader and computes SHA-256 while reading.
+type hashingReader struct {
+	reader    io.Reader
+	hasher    hash.Hash
+	bytesRead int64
+}
+
+// newHashingReader creates a new hashing reader that computes SHA-256
+// hash while reading from the provided reader.
+func newHashingReader(r io.Reader) *hashingReader {
+	return &hashingReader{
+		reader:    r,
+		hasher:    sha256.New(),
+		bytesRead: 0,
+	}
+}
+
+// Read implements io.Reader for hashingReader.
+func (hr *hashingReader) Read(p []byte) (int, error) {
+	n, err := hr.reader.Read(p)
+	if n > 0 {
+		hr.hasher.Write(p[:n])
+		hr.bytesRead += int64(n)
+	}
+	return n, err
+}
+
+// Sum returns the SHA-256 hash of all data read so far.
+func (hr *hashingReader) Sum() []byte {
+	return hr.hasher.Sum(nil)
+}
+
+// BytesRead returns the total number of bytes read.
+func (hr *hashingReader) BytesRead() int64 {
+	return hr.bytesRead
+}
+
+// computeExpectedHash computes the expected SHA-256 hash for a file of
+// given size.
+func computeExpectedHash(size int64) []byte {
+	hasher := sha256.New()
+	gen := newDeterministicDataGenerator(size)
+	// The generator's Read never fails (it only returns io.EOF), so the
+	// copy error can safely be ignored.
+	_, _ = io.Copy(hasher, gen)
+	return hasher.Sum(nil)
+}
+
+// getTestFileSize returns an appropriate file size for testing based on
+// whether we're running under the race detector or other conditions.
+// The returned size ensures parallel downloads will still occur (larger than
+// typical minimum chunk sizes of 1-10MB).
+func getTestFileSize(baseSize int64) int64 {
+	// Allow environment override for custom testing.
+	if sizeStr := os.Getenv("TEST_FILE_SIZE"); sizeStr != "" {
+		if size, err := strconv.ParseInt(sizeStr, 10, 64); err == nil {
+			return size
+		}
+	}
+
+	// Check for race detector or coverage mode.
+	if testing.CoverMode() != "" || raceEnabled {
+		// Use ~200MB for "large" (1GB) and ~400MB for "very large" (4GB).
+		// This is large enough to trigger parallel downloads with typical
+		// chunk sizes of 4-8MB, but small enough to run quickly.
+		if baseSize >= 4*1024*1024*1024 {
+			return 400 * 1024 * 1024 // 400MB instead of 4GB.
+		}
+		return 200 * 1024 * 1024 // 200MB instead of 1GB.
+ } + + return baseSize +} + +// TestLargeFile_ParallelVsSequential tests parallel vs sequential +// download of a large file. The actual file size adapts based on whether +// the race detector is enabled (200MB in race mode, 1GB normally). +func TestLargeFile_ParallelVsSequential(t *testing.T) { + if testing.Short() { + t.Skip("Skipping large file test in short mode") + } + + // Test with large file (1GB normally, 200MB in race/coverage mode). + baseSize := int64(1024 * 1024 * 1024) // 1 GB base size. + size := getTestFileSize(baseSize) + + if size != baseSize { + t.Logf("Running with reduced file size: %d MB (race detector or coverage mode detected)", + size/(1024*1024)) + } + + url := fmt.Sprintf("https://parallel.example/data/%d", size) + + // Prepare fake transport resource metadata once for logging consistency. + resourceETag := fmt.Sprintf(`"test-file-%d"`, size) + + // Compute expected hash. + expectedHash := computeExpectedHash(size) + + t.Run("Sequential", func(t *testing.T) { + transport := testutil.NewFakeTransport() + addLargeFileResource(transport, url, size) + client := &http.Client{Transport: transport} + + resp, err := client.Get(url) + if err != nil { + t.Fatalf("Failed to get %s: %v", url, err) + } + defer resp.Body.Close() + + if resp.Header.Get("ETag") != resourceETag { + t.Errorf("Expected ETag %s, got %s", resourceETag, resp.Header.Get("ETag")) + } + + if resp.ContentLength != size { + t.Errorf("Expected Content-Length %d, got %d", + size, resp.ContentLength) + } + + hashingReader := newHashingReader(resp.Body) + _, err = io.Copy(io.Discard, hashingReader) + if err != nil { + t.Fatalf("Failed to read response body: %v", err) + } + + if hashingReader.BytesRead() != size { + t.Errorf("Expected to read %d bytes, actually read %d bytes", + size, hashingReader.BytesRead()) + } + + actualHash := hashingReader.Sum() + if !bytes.Equal(expectedHash, actualHash) { + t.Errorf("Hash mismatch.\nExpected: %x\nActual: %x", + expectedHash, actualHash) + } + }) + + t.Run("Parallel", func(t *testing.T) { + baseTransport := testutil.NewFakeTransport() + addLargeFileResource(baseTransport, url, size) + transport := New( + baseTransport, + WithMaxConcurrentPerHost(map[string]uint{"": 0}), + WithMinChunkSize(4*1024*1024), // 4MB chunks. + WithMaxConcurrentPerRequest(8), + ) + client := &http.Client{Transport: transport} + + resp, err := client.Get(url) + if err != nil { + t.Fatalf("Failed to get %s: %v", url, err) + } + defer resp.Body.Close() + + if resp.Header.Get("ETag") != resourceETag { + t.Errorf("Expected ETag %s, got %s", resourceETag, resp.Header.Get("ETag")) + } + + if resp.ContentLength != size { + t.Errorf("Expected Content-Length %d, got %d", + size, resp.ContentLength) + } + + hashingReader := newHashingReader(resp.Body) + _, err = io.Copy(io.Discard, hashingReader) + if err != nil { + t.Fatalf("Failed to read response body: %v", err) + } + + if hashingReader.BytesRead() != size { + t.Errorf("Expected to read %d bytes, actually read %d bytes", + size, hashingReader.BytesRead()) + } + + actualHash := hashingReader.Sum() + if !bytes.Equal(expectedHash, actualHash) { + t.Errorf("Hash mismatch.\nExpected: %x\nActual: %x", + expectedHash, actualHash) + } + }) +} + +// TestVeryLargeFile_ParallelDownload tests parallel download of a very large +// file. The actual file size adapts based on whether the race detector is +// enabled (400MB in race mode, 4GB normally). 
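+// Set the TEST_FILE_SIZE environment variable (a byte count) to override
+// the size entirely; see getTestFileSize.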
+func TestVeryLargeFile_ParallelDownload(t *testing.T) {
+	if testing.Short() {
+		t.Skip("Skipping very large file test in short mode")
+	}
+
+	// Test with a very large file (4GB normally, 400MB in race/coverage mode).
+	baseSize := int64(4 * 1024 * 1024 * 1024) // 4 GB base size.
+	size := getTestFileSize(baseSize)
+
+	if size != baseSize {
+		t.Logf("Running with reduced file size: %d MB (race detector or coverage mode detected)",
+			size/(1024*1024))
+	}
+
+	url := fmt.Sprintf("https://parallel.example/very-large/%d", size)
+
+	baseTransport := testutil.NewFakeTransport()
+	addLargeFileResource(baseTransport, url, size)
+
+	// Only test parallel for very large files due to time constraints.
+	transport := New(
+		baseTransport,
+		WithMaxConcurrentPerHost(map[string]uint{"": 0}),
+		WithMinChunkSize(8*1024*1024), // 8MB chunks.
+		WithMaxConcurrentPerRequest(16),
+	)
+	client := &http.Client{Transport: transport}
+
+	resp, err := client.Get(url)
+	if err != nil {
+		t.Fatalf("Failed to get %s: %v", url, err)
+	}
+	defer resp.Body.Close()
+
+	if resp.ContentLength != size {
+		t.Errorf("Expected Content-Length %d, got %d",
+			size, resp.ContentLength)
+	}
+
+	// For very large files, just verify that the expected number of bytes
+	// is read. Computing the full hash would take too long.
+	bytesRead := int64(0)
+	buf := make([]byte, 64*1024) // 64KB buffer.
+	for {
+		n, err := resp.Body.Read(buf)
+		bytesRead += int64(n)
+		if err == io.EOF {
+			break
+		}
+		if err != nil {
+			t.Fatalf("Failed to read response body: %v", err)
+		}
+	}
+
+	if bytesRead != size {
+		t.Errorf("Expected to read %d bytes, actually read %d bytes",
+			size, bytesRead)
+	}
+
+	t.Logf("Successfully read %d bytes from parallel download",
+		bytesRead)
+}
diff --git a/transport/parallel/race_off.go b/transport/parallel/race_off.go
new file mode 100644
index 0000000..cafa39d
--- /dev/null
+++ b/transport/parallel/race_off.go
@@ -0,0 +1,8 @@
+//go:build !race
+// +build !race
+
+package parallel
+
+// raceEnabled is a compile-time constant indicating whether the race
+// detector is enabled.
+const raceEnabled = false
diff --git a/transport/parallel/race_on.go b/transport/parallel/race_on.go
new file mode 100644
index 0000000..aa1bf77
--- /dev/null
+++ b/transport/parallel/race_on.go
@@ -0,0 +1,8 @@
+//go:build race
+// +build race
+
+package parallel
+
+// raceEnabled is a compile-time constant indicating whether the race
+// detector is enabled.
+const raceEnabled = true
diff --git a/transport/parallel/transport.go b/transport/parallel/transport.go
new file mode 100644
index 0000000..f8b1a55
--- /dev/null
+++ b/transport/parallel/transport.go
@@ -0,0 +1,700 @@
+// Package parallel provides an http.RoundTripper that transparently
+// parallelizes GET requests using concurrent byte-range requests for better
+// throughput.
+//
+// ───────────────────────────── How it works ─────────────────────────────
+// - For non-GET requests, the transport passes them through unmodified to
+//   the underlying transport.
+// - For GET requests, it first performs a HEAD request to check if the
+//   server supports byte ranges and to determine the total response size.
+// - If the HEAD request indicates range support and known size, the
+//   transport generates multiple concurrent GET requests with specific
+//   byte-range headers.
+// - Subranges are written to temporary files and stitched together in a
+//   custom Response.Body that's transparent to the caller.
+// - Per-host and per-request concurrency limits are enforced using
+//   semaphores.
+// +// ───────────────────────────── Notes & caveats ─────────────────────────── +// - Only works with servers that support "Accept-Ranges: bytes" and +// provide Content-Length or Content-Range headers with total size +// information. +// - Content-Encoding (compression) is not compatible with byte ranges, +// so compressed responses fall back to single-threaded behavior. +// - Temporary files are created for each subrange and cleaned up +// automatically. +// - The transport respects per-host concurrency limits to avoid +// overwhelming servers. +package parallel + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "net/http" + "os" + "strconv" + "strings" + "sync" + + "github.com/docker/model-distribution/transport/internal/bufferfile" + "github.com/docker/model-distribution/transport/internal/common" +) + +// Option configures a ParallelTransport. +type Option func(*ParallelTransport) + +// WithMaxConcurrentPerHost sets the maximum concurrent requests per +// hostname. Default concurrency limits are applied if not specified. +func WithMaxConcurrentPerHost(limits map[string]uint) Option { + return func(pt *ParallelTransport) { + pt.maxConcurrentPerHost = make(map[string]uint, len(limits)) + for host, limit := range limits { + pt.maxConcurrentPerHost[host] = limit + } + } +} + +// WithMaxConcurrentPerRequest sets the maximum concurrent subrange +// requests for a single request. Default: 4. +func WithMaxConcurrentPerRequest(n uint) Option { + return func(pt *ParallelTransport) { pt.maxConcurrentPerRequest = n } +} + +// WithMinChunkSize sets the minimum size in bytes for each subrange chunk. +// Requests smaller than this will not be parallelized. Default: 1MB. +func WithMinChunkSize(size int64) Option { + return func(pt *ParallelTransport) { pt.minChunkSize = size } +} + +// WithTempDir sets the directory for temporary files. If empty, +// os.TempDir() is used. +func WithTempDir(dir string) Option { + return func(pt *ParallelTransport) { pt.tempDir = dir } +} + +// ParallelTransport wraps another http.RoundTripper and parallelizes GET +// requests using concurrent byte-range requests when possible. +type ParallelTransport struct { + // base is the underlying RoundTripper actually used to send requests. + base http.RoundTripper + // maxConcurrentPerHost maps canonicalized hostname to maximum + // concurrent requests. A value of 0 means unlimited. The "" entry is + // the default for unspecified hosts. + maxConcurrentPerHost map[string]uint + // maxConcurrentPerRequest is the maximum number of concurrent + // subrange requests for a single request. + maxConcurrentPerRequest uint + // minChunkSize is the minimum size in bytes for parallelization to be + // worthwhile. + minChunkSize int64 + // tempDir is the directory for temporary files. + tempDir string + // semaphores tracks per-host concurrency limits. + semaphores map[string]*semaphore + // semMu protects the semaphores map. + semMu sync.RWMutex +} + +// New returns a ParallelTransport wrapping base. If base is nil, +// http.DefaultTransport is used. Options configure parallelization behavior. +func New(base http.RoundTripper, opts ...Option) *ParallelTransport { + if base == nil { + base = http.DefaultTransport + } + pt := &ParallelTransport{ + base: base, + maxConcurrentPerHost: map[string]uint{"": 4}, // default 4 per host. + maxConcurrentPerRequest: 4, + minChunkSize: 1024 * 1024, // 1MB. 
+ tempDir: os.TempDir(), + semaphores: make(map[string]*semaphore), + } + for _, o := range opts { + o(pt) + } + return pt +} + +// RoundTrip implements http.RoundTripper. It parallelizes GET requests +// when possible, otherwise passes requests through to the underlying +// transport. +func (pt *ParallelTransport) RoundTrip(req *http.Request) (*http.Response, error) { + // Non-GET requests pass through unmodified. + if req.Method != http.MethodGet { + return pt.base.RoundTrip(req) + } + // Respect caller-provided Range requests. We do not parallelize when the + // request already specifies a byte range, to preserve exact semantics. + if strings.TrimSpace(req.Header.Get("Range")) != "" { + return pt.base.RoundTrip(req) + } + + // Check if parallelization is possible and worthwhile. + canParallelize, pInfo, err := pt.checkParallelizable(req) + if err != nil { + return nil, err + } + if !canParallelize || + pInfo.totalSize < pt.minChunkSize*int64(pt.maxConcurrentPerRequest) { + // Fall back to single request. + return pt.base.RoundTrip(req) + } + + // Perform parallel download. + return pt.parallelDownload(req, pInfo) +} + +// parallelInfo holds information needed for parallel downloads. +type parallelInfo struct { + // totalSize is the total size of the resource in bytes. + totalSize int64 + // etag is the strong ETag validator from the HEAD response, used for + // If-Range. + etag string + // lastModified is the Last-Modified header value, used as fallback + // validator for If-Range. + lastModified string + // header is a clone of the server headers (from HEAD) used to seed the + // final response headers without an extra GET probe. + header http.Header + // proto/protoMajor/protoMinor reflect the server protocol from the HEAD + // response for constructing the final response. + proto string + protoMajor int + protoMinor int +} + +// checkParallelizable performs a HEAD request to determine if the resource +// supports byte ranges and returns the parallel info if available. +func (pt *ParallelTransport) checkParallelizable(req *http.Request) (bool, *parallelInfo, error) { + // Create HEAD request. + headReq := req.Clone(req.Context()) + headReq.Method = http.MethodHead + headReq.Body = nil + headReq.ContentLength = 0 + // Clone and sanitize headers to avoid conditional responses and implicit + // compression that could skew metadata. + headReq.Header = req.Header.Clone() + common.ScrubConditionalHeaders(headReq.Header) + headReq.Header.Set("Accept-Encoding", "identity") + + // Perform HEAD request. + headResp, err := pt.base.RoundTrip(headReq) + if err != nil { + return false, nil, err + } + defer headResp.Body.Close() + + // Only proceed on 200 OK or 206 Partial Content. Anything else (e.g., + // 304 Not Modified due to missed scrub, redirects, etc.) is treated as + // non-parallelizable for safety. + if headResp.StatusCode != http.StatusOK && + headResp.StatusCode != http.StatusPartialContent { + return false, nil, nil + } + + // Check if range requests are supported. + if !common.SupportsRange(headResp.Header) { + return false, nil, nil + } + + // Check for compression which would interfere with byte ranges. + if headResp.Header.Get("Content-Encoding") != "" { + return false, nil, nil + } + + // Get total content length. + totalSize := headResp.ContentLength + if totalSize <= 0 { + // Try to parse from Content-Range if present (206 response). 
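+		// For example, "Content-Range: bytes 0-0/1048576" reports the
+		// total size (1048576) after the slash.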
+		if headResp.StatusCode == http.StatusPartialContent {
+			if _, _, total, ok := common.ParseContentRange(
+				headResp.Header.Get("Content-Range")); ok && total > 0 {
+				totalSize = total
+			} else {
+				return false, nil, nil
+			}
+		} else {
+			return false, nil, nil
+		}
+	}
+
+	if totalSize <= 0 {
+		return false, nil, nil
+	}
+
+	// Capture validators for If-Range to ensure consistency across parallel
+	// requests.
+	info := &parallelInfo{
+		totalSize:  totalSize,
+		header:     headResp.Header.Clone(),
+		proto:      headResp.Proto,
+		protoMajor: headResp.ProtoMajor,
+		protoMinor: headResp.ProtoMinor,
+	}
+
+	if et := headResp.Header.Get("ETag"); et != "" && !common.IsWeakETag(et) {
+		info.etag = et
+	} else if lm := headResp.Header.Get("Last-Modified"); lm != "" {
+		info.lastModified = lm
+	}
+
+	return true, info, nil
+}
+
+// parallelDownload performs a parallel download by splitting the request
+// into multiple concurrent byte-range requests.
+func (pt *ParallelTransport) parallelDownload(req *http.Request, pInfo *parallelInfo) (*http.Response, error) {
+	totalSize := pInfo.totalSize
+
+	// Calculate chunk size and number of chunks.
+	numChunks := int(pt.maxConcurrentPerRequest)
+	if totalSize < int64(numChunks)*pt.minChunkSize {
+		numChunks = int(totalSize / pt.minChunkSize)
+		if numChunks < 1 {
+			numChunks = 1
+		}
+	}
+
+	chunkSize := totalSize / int64(numChunks)
+	remainder := totalSize % int64(numChunks)
+
+	// Get or create semaphore for this host.
+	sem := pt.getSemaphore(req.URL.Host)
+
+	// Create chunks and temporary files.
+	chunks := make([]*chunk, numChunks)
+	var start int64
+	for i := 0; i < numChunks; i++ {
+		size := chunkSize
+		if i == numChunks-1 {
+			size += remainder // Last chunk gets the remainder.
+		}
+		end := start + size - 1
+
+		fifo, err := bufferfile.NewFIFOInDir(pt.tempDir)
+		if err != nil {
+			// Clean up any created FIFOs.
+			for j := 0; j < i; j++ {
+				chunks[j].cleanup()
+			}
+			return nil, fmt.Errorf("parallel: failed to create FIFO: %w", err)
+		}
+
+		chunk := &chunk{
+			start: start,
+			end:   end,
+			fifo:  fifo,
+			state: chunkNotStarted,
+		}
+		chunks[i] = chunk
+		start = end + 1
+	}
+
+	// Start downloading chunks concurrently (don't wait for completion).
+	for i, ch := range chunks {
+		go func(i int, ch *chunk) {
+			ch.setSimpleState(chunkDownloading, nil)
+			if err := pt.downloadChunk(req, ch, sem, pInfo); err != nil {
+				ch.setSimpleState(chunkFailed, fmt.Errorf("chunk %d: %w", i, err))
+				ch.fifo.Close() // Close FIFO on error to interrupt readers.
+			} else {
+				ch.setSimpleState(chunkCompleted, nil)
+				// Close write side to signal no more writes (EOF when all data
+				// read).
+				ch.fifo.CloseWrite()
+			}
+		}(i, ch)
+	}
+
+	// Create stitched response.
+	body := &stitchedBody{
+		chunks:    chunks,
+		totalSize: totalSize,
+		ctx:       req.Context(),
+	}
+
+	// Create response using the header response as template.
+	resp := &http.Response{
+		Status:        "200 OK",
+		StatusCode:    http.StatusOK,
+		Proto:         pInfo.proto,
+		ProtoMajor:    pInfo.protoMajor,
+		ProtoMinor:    pInfo.protoMinor,
+		Header:        pInfo.header.Clone(),
+		Body:          body,
+		ContentLength: totalSize,
+		Request:       req,
+	}
+
+	// Override headers that we control.
+	resp.Header.Set("Content-Length", strconv.FormatInt(totalSize, 10))
+	resp.Header.Del("Content-Range") // Remove any partial content headers.
+
+	return resp, nil
+}
+
+// downloadChunk downloads a single chunk using a byte-range request.
+func (pt *ParallelTransport) downloadChunk(origReq *http.Request, chunk *chunk, sem *semaphore, pInfo *parallelInfo) error {
+	// Acquire semaphore.
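+	// The semaphore is keyed by host, so the chunks of this request compete
+	// with every other in-flight request to the same host.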
+	if err := sem.acquire(origReq.Context()); err != nil {
+		return err
+	}
+	defer sem.release()
+
+	// Create range request.
+	rangeReq := origReq.Clone(origReq.Context())
+	rangeReq.Header = origReq.Header.Clone()
+	rangeReq.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", chunk.start, chunk.end))
+
+	// Prevent compression which would interfere with byte ranges.
+	rangeReq.Header.Set("Accept-Encoding", "identity")
+
+	// Remove conditional headers first so they cannot conflict with (or
+	// scrub away) the If-Range validator set below.
+	common.ScrubConditionalHeaders(rangeReq.Header)
+
+	// Add If-Range header for consistency validation.
+	if pInfo.etag != "" {
+		rangeReq.Header.Set("If-Range", pInfo.etag)
+	} else if pInfo.lastModified != "" {
+		rangeReq.Header.Set("If-Range", pInfo.lastModified)
+	}
+
+	// Perform request.
+	resp, err := pt.base.RoundTrip(rangeReq)
+	if err != nil {
+		return err
+	}
+	defer resp.Body.Close()
+
+	// Check for If-Range validation failure (server returns 200 instead of 206).
+	if resp.StatusCode == http.StatusOK {
+		return fmt.Errorf(
+			"server returned 200 to range request, resource may have changed (If-Range validation failed)")
+	}
+
+	// Verify we got a partial content response.
+	if resp.StatusCode != http.StatusPartialContent {
+		return fmt.Errorf(
+			"expected 206 Partial Content, got %d", resp.StatusCode)
+	}
+
+	// Verify the range matches what we requested.
+	if start, end, _, ok := common.ParseContentRange(resp.Header.Get("Content-Range")); ok {
+		if start != chunk.start || end != chunk.end {
+			return fmt.Errorf(
+				"server returned range %d-%d, requested %d-%d",
+				start, end, chunk.start, chunk.end)
+		}
+	}
+
+	// Copy response body to FIFO and verify full chunk length is received.
+	buf := make([]byte, 32*1024) // 32KB buffer.
+	var copied int64
+	for {
+		n, err := resp.Body.Read(buf)
+		if n > 0 {
+			// Write to FIFO.
+			if _, writeErr := chunk.fifo.Write(buf[:n]); writeErr != nil {
+				return fmt.Errorf(
+					"failed to write chunk data: %w", writeErr)
+			}
+			copied += int64(n)
+		}
+
+		if err == io.EOF {
+			// Validate that we received the complete range we requested.
+			expected := (chunk.end - chunk.start + 1)
+			if copied != expected {
+				return fmt.Errorf(
+					"short read for chunk: got %d, want %d", copied, expected)
+			}
+			break
+		}
+		if err != nil {
+			return fmt.Errorf(
+				"failed to read chunk data: %w", err)
+		}
+	}
+
+	return nil
+}
+
+// getSemaphore returns the semaphore for the given host, creating it if needed.
+func (pt *ParallelTransport) getSemaphore(host string) *semaphore {
+	canonicalHost := canonicalizeHost(host)
+
+	pt.semMu.RLock()
+	if sem, exists := pt.semaphores[canonicalHost]; exists {
+		pt.semMu.RUnlock()
+		return sem
+	}
+	pt.semMu.RUnlock()
+
+	pt.semMu.Lock()
+	defer pt.semMu.Unlock()
+
+	// Double-check after acquiring write lock.
+	if sem, exists := pt.semaphores[canonicalHost]; exists {
+		return sem
+	}
+
+	// Determine the limit for this host. Use the comma-ok form so that an
+	// explicit 0 (unlimited) entry is not mistaken for a missing one.
+	limit, ok := pt.maxConcurrentPerHost[canonicalHost]
+	if !ok {
+		// Fall back to the default ("") entry for unspecified hosts.
+		limit = pt.maxConcurrentPerHost[""]
+	}
+
+	sem := newSemaphore(int(limit))
+	pt.semaphores[canonicalHost] = sem
+	return sem
+}
+
+// canonicalizeHost returns a canonical form of the hostname for semaphore lookup.
+func canonicalizeHost(host string) string {
+	// Remove port if present.
+	if h, _, err := net.SplitHostPort(host); err == nil {
+		host = h
+	}
+	return strings.ToLower(host)
+}
+
+// chunkState represents the current state of a chunk download.
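+// Transitions are chunkNotStarted -> chunkDownloading, then either
+// chunkCompleted or chunkFailed; a chunk is never retried.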
+type chunkState int
+
+const (
+	chunkNotStarted chunkState = iota
+	chunkDownloading
+	chunkCompleted
+	chunkFailed
+)
+
+// chunk represents a byte range chunk being downloaded to a temporary file.
+type chunk struct {
+	// start is the inclusive starting byte offset for this chunk.
+	start int64
+	// end is the inclusive ending byte offset for this chunk.
+	end int64
+	// fifo is the FIFO buffer where this chunk's data is stored.
+	fifo *bufferfile.FIFO
+	// state tracks the current download state of this chunk.
+	state chunkState
+	// err holds any error that occurred during download.
+	err error
+	// mu protects state and err fields.
+	mu sync.Mutex
+}
+
+// close closes the FIFO handle.
+func (c *chunk) close() error {
+	if c.fifo == nil {
+		return nil
+	}
+	return c.fifo.Close()
+}
+
+// cleanup closes and removes the FIFO.
+func (c *chunk) cleanup() {
+	if c.fifo != nil {
+		// Only close the FIFO. Do not nil the pointer to avoid races with
+		// in-flight writer goroutines checking or using this handle.
+		c.fifo.Close()
+	}
+}
+
+// setSimpleState updates the chunk state. No condition signaling is needed
+// since the FIFO handles coordination.
+func (c *chunk) setSimpleState(state chunkState, err error) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	c.state = state
+	c.err = err
+}
+
+// readAvailable reads up to len(p) bytes from the chunk, blocking until data
+// is available. It returns the number of bytes read and any error, and
+// returns io.EOF only when the chunk is complete and all data has been read.
+func (c *chunk) readAvailable(ctx context.Context, p []byte) (int, error) {
+	// Check for context cancellation.
+	select {
+	case <-ctx.Done():
+		return 0, ctx.Err()
+	default:
+	}
+
+	// Check whether the chunk has already failed.
+	c.mu.Lock()
+	if c.state == chunkFailed && c.err != nil {
+		err := c.err
+		c.mu.Unlock()
+		return 0, err
+	}
+	c.mu.Unlock()
+
+	// Try to read from the FIFO.
+	n, err := c.fifo.Read(p)
+
+	// If we got data, return it.
+	if n > 0 {
+		return n, nil
+	}
+
+	// On EOF, distinguish a completed chunk from one whose FIFO was closed
+	// early because the download failed.
+	if err == io.EOF {
+		c.mu.Lock()
+		state, cerr := c.state, c.err
+		c.mu.Unlock()
+		if state == chunkCompleted {
+			return 0, io.EOF
+		}
+		if state == chunkFailed && cerr != nil {
+			return 0, cerr
+		}
+		// Otherwise surface the EOF as-is.
+	}
+
+	return n, err
+}
+
+// stitchedBody implements io.ReadCloser by reading from multiple chunk files in sequence.
+type stitchedBody struct {
+	// chunks is the ordered list of chunk files to read from.
+	chunks []*chunk
+	// totalSize is the expected total number of bytes across all chunks.
+	totalSize int64
+	// currentIdx is the index of the chunk currently being read from.
+	currentIdx int
+	// bytesRead is the total number of bytes delivered to callers so far.
+	bytesRead int64
+	// closed indicates whether Close() has been called.
+	closed bool
+	// ctx is the request context for cancellation.
+	ctx context.Context
+	// mu protects all fields from concurrent access.
+	mu sync.Mutex
+}
+
+// Read reads data by stitching together chunks in order.
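+// It blocks until the current chunk has data available and returns io.EOF
+// only once every chunk has been consumed in order.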
+func (sb *stitchedBody) Read(p []byte) (int, error) {
+	sb.mu.Lock()
+	defer sb.mu.Unlock()
+
+	if sb.closed {
+		return 0, errors.New("stitchedBody: read from closed body")
+	}
+
+	if sb.currentIdx >= len(sb.chunks) {
+		return 0, io.EOF
+	}
+
+	totalRead := 0
+	for len(p) > 0 && sb.currentIdx < len(sb.chunks) {
+		ch := sb.chunks[sb.currentIdx]
+
+		// Unlock while reading from the chunk (the chunk handles its own
+		// locking).
+		sb.mu.Unlock()
+
+		// Read available data from the current chunk.
+		n, err := ch.readAvailable(sb.ctx, p)
+
+		// Re-lock to update state.
+		sb.mu.Lock()
+
+		if sb.closed {
+			return totalRead, errors.New("stitchedBody: read from closed body")
+		}
+
+		if n > 0 {
+			totalRead += n
+			sb.bytesRead += int64(n)
+			p = p[n:]
+		}
+
+		if err == io.EOF {
+			// Current chunk is complete, move to the next one.
+			sb.currentIdx++
+		} else if err != nil {
+			return totalRead, fmt.Errorf("stitchedBody: chunk %d error: %w", sb.currentIdx, err)
+		} else if n == 0 {
+			// No error but no data read - this shouldn't happen with
+			// readAvailable, but handle it to avoid infinite loops.
+			return totalRead, fmt.Errorf("stitchedBody: chunk %d read 0 bytes without error or EOF", sb.currentIdx)
+		}
+	}
+
+	if totalRead == 0 && sb.currentIdx >= len(sb.chunks) {
+		return 0, io.EOF
+	}
+
+	return totalRead, nil
+}
+
+// Close closes all chunk files and cleans up temporary files.
+func (sb *stitchedBody) Close() error {
+	sb.mu.Lock()
+	defer sb.mu.Unlock()
+
+	if sb.closed {
+		return nil
+	}
+	sb.closed = true
+
+	var errs []error
+	for _, ch := range sb.chunks {
+		if err := ch.close(); err != nil {
+			errs = append(errs, err)
+		}
+		ch.cleanup()
+	}
+
+	if len(errs) > 0 {
+		return fmt.Errorf("stitchedBody: close errors: %v", errs)
+	}
+	return nil
+}
+
+// semaphore implements a counting semaphore for limiting concurrency.
+type semaphore struct {
+	// ch is the buffered channel used to limit concurrent operations.
+	// If nil, no limits are enforced (unlimited concurrency).
+	ch chan struct{}
+}
+
+// newSemaphore creates a new semaphore with the given capacity.
+// If capacity is 0 or negative, the semaphore allows unlimited concurrency.
+func newSemaphore(capacity int) *semaphore {
+	if capacity <= 0 {
+		// Unlimited semaphore - nil channel means no limits.
+		return &semaphore{}
+	}
+	return &semaphore{
+		ch: make(chan struct{}, capacity),
+	}
+}
+
+// acquire acquires a semaphore slot, blocking until one is available or context is canceled.
+func (s *semaphore) acquire(ctx context.Context) error {
+	if s.ch == nil {
+		// Unlimited semaphore - no need to acquire.
+		return nil
+	}
+	select {
+	case s.ch <- struct{}{}:
+		return nil
+	case <-ctx.Done():
+		return ctx.Err()
+	}
+}
+
+// release releases a semaphore slot.
+func (s *semaphore) release() {
+	if s.ch == nil {
+		// Unlimited semaphore - no need to release.
+		return
+	}
+	<-s.ch
+}
diff --git a/transport/parallel/transport_test.go b/transport/parallel/transport_test.go
new file mode 100644
index 0000000..f4cc626
--- /dev/null
+++ b/transport/parallel/transport_test.go
@@ -0,0 +1,847 @@
+package parallel
+
+import (
+	"bytes"
+	"io"
+	"net/http"
+	"sync"
+	"testing"
+	"time"
+
+	testutil "github.com/docker/model-distribution/transport/internal/testing"
+)
+
+// TestParallelDownload_Success verifies parallel downloads using
+// testutil.FakeTransport.
+func TestParallelDownload_Success(t *testing.T) {
+	url := "https://example.com/large-file"
+	payload := testutil.GenerateTestData(100000) // 100KB.
+ + ft := testutil.NewFakeTransport() + ft.Add(url, &testutil.FakeResource{ + Data: bytes.NewReader(payload), + Length: int64(len(payload)), + SupportsRange: true, + ETag: `"test-etag"`, + }) + + client := &http.Client{ + Transport: New(ft, WithMaxConcurrentPerRequest(4), WithMinChunkSize(1024)), + } + + resp, err := client.Get(url) + if err != nil { + t.Fatalf("GET: %v", err) + } + defer resp.Body.Close() + + got, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read: %v", err) + } + + testutil.AssertDataEquals(t, got, payload) + + // Verify parallel requests were made. + reqs := ft.GetRequests() + var headCount, rangeCount, getCount int + for _, req := range reqs { + if req.Method == http.MethodHead { + headCount++ + } else if req.Method == http.MethodGet { + getCount++ + if req.Header.Get("Range") != "" { + rangeCount++ + } + } + t.Logf("Request: %s %s, Range: %s", + req.Method, req.URL, req.Header.Get("Range")) + } + + if headCount != 1 { + t.Errorf("expected 1 HEAD request, got %d", headCount) + } + if rangeCount < 2 { + t.Errorf("expected at least 2 range requests, got %d (total GET: %d)", + rangeCount, getCount) + } +} + +// TestSmallFile_FallsBackToSingle verifies small files aren't parallelized. +func TestSmallFile_FallsBackToSingle(t *testing.T) { + url := "https://example.com/small-file" + payload := []byte("small content") + + ft := testutil.NewFakeTransport() + ft.AddSimple(url, bytes.NewReader(payload), int64(len(payload)), true) + + client := &http.Client{ + Transport: New(ft, WithMaxConcurrentPerRequest(4), WithMinChunkSize(1024)), + } + + resp, err := client.Get(url) + if err != nil { + t.Fatalf("GET: %v", err) + } + defer resp.Body.Close() + + got, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read: %v", err) + } + + testutil.AssertDataEquals(t, got, payload) + + // Should only have HEAD and single GET. + reqs := ft.GetRequests() + var headCount, rangeCount, fullGetCount int + for _, req := range reqs { + if req.Method == http.MethodHead { + headCount++ + } else if req.Method == http.MethodGet { + if req.Header.Get("Range") != "" { + rangeCount++ + } else { + fullGetCount++ + } + } + } + + if headCount != 1 { + t.Errorf("expected 1 HEAD request, got %d", headCount) + } + if rangeCount != 0 { + t.Errorf("expected 0 range requests, got %d", rangeCount) + } + if fullGetCount != 1 { + t.Errorf("expected 1 full GET request, got %d", fullGetCount) + } +} + +// TestNoRangeSupport_FallsBack tests fallback when server doesn't support +// ranges. +func TestNoRangeSupport_FallsBack(t *testing.T) { + url := "https://example.com/no-range" + payload := testutil.GenerateTestData(100000) + + ft := testutil.NewFakeTransport() + ft.Add(url, &testutil.FakeResource{ + Data: bytes.NewReader(payload), + Length: int64(len(payload)), + SupportsRange: false, // No range support. + }) + + client := &http.Client{Transport: New(ft)} + + resp, err := client.Get(url) + if err != nil { + t.Fatalf("GET: %v", err) + } + defer resp.Body.Close() + + got, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read: %v", err) + } + + testutil.AssertDataEquals(t, got, payload) + + // Should fall back to single request. + reqs := ft.GetRequests() + var rangeCount int + for _, req := range reqs { + if req.Header.Get("Range") != "" { + rangeCount++ + } + } + + if rangeCount != 0 { + t.Errorf("expected no range requests, got %d", rangeCount) + } +} + +// TestContentEncoding_FallsBack tests fallback with Content-Encoding. 
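+// Compressed bodies cannot be reassembled from independent byte ranges, so
+// the transport must fall back to a single plain GET.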
+func TestContentEncoding_FallsBack(t *testing.T) {
+	url := "https://example.com/gzip"
+	payload := testutil.GenerateTestData(100000)
+
+	ft := testutil.NewFakeTransport()
+	ft.Add(url, &testutil.FakeResource{
+		Data:          bytes.NewReader(payload),
+		Length:        int64(len(payload)),
+		SupportsRange: true,
+		Headers: http.Header{
+			"Content-Encoding": []string{"gzip"},
+		},
+	})
+
+	// Use a small minimum chunk size so that only the Content-Encoding
+	// check (not the payload size) prevents parallelization.
+	client := &http.Client{Transport: New(ft, WithMinChunkSize(1024))}
+
+	resp, err := client.Get(url)
+	if err != nil {
+		t.Fatalf("GET: %v", err)
+	}
+	defer resp.Body.Close()
+
+	got, err := io.ReadAll(resp.Body)
+	if err != nil {
+		t.Fatalf("read: %v", err)
+	}
+
+	testutil.AssertDataEquals(t, got, payload)
+
+	// Should fall back due to Content-Encoding.
+	reqs := ft.GetRequests()
+	var rangeCount int
+	for _, req := range reqs {
+		if req.Header.Get("Range") != "" {
+			rangeCount++
+		}
+	}
+
+	if rangeCount != 0 {
+		t.Errorf("expected no range requests due to Content-Encoding, got %d",
+			rangeCount)
+	}
+}
+
+// TestETagValidation verifies that the ETag is used for If-Range validation.
+func TestETagValidation(t *testing.T) {
+	url := "https://example.com/etag-test"
+	payload := testutil.GenerateTestData(100000)
+
+	ft := testutil.NewFakeTransport()
+	ft.Add(url, &testutil.FakeResource{
+		Data:          bytes.NewReader(payload),
+		Length:        int64(len(payload)),
+		SupportsRange: true,
+		ETag:          `"strong-etag"`,
+	})
+
+	// Use a small minimum chunk size so the 100KB payload is actually
+	// parallelized and the range requests carry If-Range.
+	client := &http.Client{Transport: New(ft, WithMinChunkSize(1024))}
+
+	resp, err := client.Get(url)
+	if err != nil {
+		t.Fatalf("GET: %v", err)
+	}
+	defer resp.Body.Close()
+
+	got, err := io.ReadAll(resp.Body)
+	if err != nil {
+		t.Fatalf("read: %v", err)
+	}
+
+	testutil.AssertDataEquals(t, got, payload)
+
+	// Check If-Range headers.
+	headers := ft.GetRequestHeaders(url)
+	sawRange := false
+	for _, h := range headers {
+		if h.Get("Range") != "" {
+			sawRange = true
+			if ifRange := h.Get("If-Range"); ifRange != `"strong-etag"` {
+				t.Errorf("expected If-Range with ETag, got %q", ifRange)
+			}
+		}
+	}
+	if !sawRange {
+		t.Error("expected range requests to be made")
+	}
+}
+
+// TestWeakETag_UsesLastModified tests that weak ETags trigger Last-Modified
+// usage.
+func TestWeakETag_UsesLastModified(t *testing.T) {
+	url := "https://example.com/weak-etag"
+	payload := testutil.GenerateTestData(100000)
+	lastModified := time.Unix(1700000000, 0).UTC().Format(http.TimeFormat)
+
+	ft := testutil.NewFakeTransport()
+	ft.Add(url, &testutil.FakeResource{
+		Data:          bytes.NewReader(payload),
+		Length:        int64(len(payload)),
+		SupportsRange: true,
+		ETag:          `W/"weak-etag"`,
+		LastModified:  lastModified,
+	})
+
+	// Use a small minimum chunk size so the payload is parallelized and the
+	// range requests carry If-Range.
+	client := &http.Client{Transport: New(ft, WithMinChunkSize(1024))}
+
+	resp, err := client.Get(url)
+	if err != nil {
+		t.Fatalf("GET: %v", err)
+	}
+	defer resp.Body.Close()
+
+	got, err := io.ReadAll(resp.Body)
+	if err != nil {
+		t.Fatalf("read: %v", err)
+	}
+
+	testutil.AssertDataEquals(t, got, payload)
+
+	// Check If-Range uses Last-Modified instead of the weak ETag.
+	headers := ft.GetRequestHeaders(url)
+	sawRange := false
+	for _, h := range headers {
+		if h.Get("Range") != "" {
+			sawRange = true
+			ifRange := h.Get("If-Range")
+			if ifRange != lastModified {
+				t.Errorf("expected If-Range with Last-Modified, got %q",
+					ifRange)
+			}
+		}
+	}
+	if !sawRange {
+		t.Error("expected range requests to be made")
+	}
+}
+
+// TestConcurrencyLimits verifies per-host concurrency limits.
+func TestConcurrencyLimits(t *testing.T) {
+	url := "https://example.com/large"
+	payload := testutil.GenerateTestData(500000) // 500KB to ensure parallelization.
+
+	ft := testutil.NewFakeTransport()
+	ft.AddSimple(url, bytes.NewReader(payload), int64(len(payload)), true)
+
+	// Track concurrent requests. maxConcurrent records the peak concurrent range
+	// downloads observed while currentConcurrent holds the in-flight count at any
+	// moment.
mu ensures those counters are updated atomically. rangeRequests + // counts how many range downloads we observed. wg waits until every tracked + // range request finishes. rangeStartedCh buffers notifications when a new + // tracked range request begins. releaseCh blocks the request until the test + // releases it. releaseOnce ensures releaseCh is only closed once, even on + // early exits. + var maxConcurrent, currentConcurrent int + var mu sync.Mutex + rangeRequests := 0 + var wg sync.WaitGroup + rangeStartedCh := make(chan struct{}, 8) + releaseCh := make(chan struct{}) + var releaseOnce sync.Once + defer releaseOnce.Do(func() { close(releaseCh) }) + + ft.RequestHook = func(req *http.Request) { + rangeHeader := req.Header.Get("Range") + if rangeHeader != "" && rangeHeader != "bytes=0-0" { + wg.Add(1) + + mu.Lock() + currentConcurrent++ + rangeRequests++ + if currentConcurrent > maxConcurrent { + maxConcurrent = currentConcurrent + } + mu.Unlock() + + // Capture the start of the range request without blocking. + select { + case rangeStartedCh <- struct{}{}: + default: + } + + <-releaseCh + + mu.Lock() + currentConcurrent-- + mu.Unlock() + + wg.Done() + } + t.Logf("Request: %s %s, Range: %s", req.Method, req.URL, rangeHeader) + } + + client := &http.Client{ + Transport: New(ft, + WithMaxConcurrentPerHost(map[string]uint{"example.com": 2}), + WithMaxConcurrentPerRequest(4), + WithMinChunkSize(10000)), // Lower min chunk size to ensure parallelization. + } + + resp, err := client.Get(url) + if err != nil { + t.Fatalf("GET: %v", err) + } + defer resp.Body.Close() + + // Drive the download in a goroutine so range requests can start while the + // test observes concurrency. + readDone := make(chan error, 1) + go func() { + _, err := io.ReadAll(resp.Body) + readDone <- err + }() + + for i := 0; i < 2; i++ { + select { + case <-rangeStartedCh: + case <-time.After(time.Second): + releaseOnce.Do(func() { close(releaseCh) }) + t.Fatalf("timed out waiting for parallel range requests to start") + } + } + + releaseOnce.Do(func() { close(releaseCh) }) + + if err := <-readDone; err != nil { + t.Fatalf("read: %v", err) + } + + wg.Wait() + + mu.Lock() + maxSeen := maxConcurrent + madeRanges := rangeRequests + mu.Unlock() + + if maxSeen > 2 { + t.Errorf("expected max 2 concurrent requests, got %d", maxSeen) + } + + if madeRanges == 0 { + t.Error("no range requests were made") + } +} + +// TestIfRangeValidation tests If-Range validation behavior. +func TestIfRangeValidation(t *testing.T) { + url := "https://example.com/if-range-test" + payload := testutil.GenerateTestData(100000) + etag := `"original-etag"` + + ft := testutil.NewFakeTransport() + ft.Add(url, &testutil.FakeResource{ + Data: bytes.NewReader(payload), + Length: int64(len(payload)), + SupportsRange: true, + ETag: etag, + }) + + // Change ETag on range requests to simulate resource change. + ft.ResponseHook = func(resp *http.Response) { + if resp.Request.Header.Get("Range") != "" { + // Check If-Range validation. + ifRange := resp.Request.Header.Get("If-Range") + if ifRange != etag { + // Resource changed, return full content. 
+ resp.StatusCode = http.StatusOK + resp.Status = "200 OK" + resp.Header.Del("Content-Range") + resp.Body = io.NopCloser(bytes.NewReader(payload)) + } + } + } + + client := &http.Client{Transport: New(ft)} + + resp, err := client.Get(url) + if err != nil { + t.Fatalf("GET: %v", err) + } + defer resp.Body.Close() + + got, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read: %v", err) + } + + testutil.AssertDataEquals(t, got, payload) +} + +// TestNoContentLength_FallsBack tests fallback when Content-Length is +// missing. +func TestNoContentLength_FallsBack(t *testing.T) { + url := "https://example.com/no-length" + payload := testutil.GenerateTestData(100000) + + ft := testutil.NewFakeTransport() + ft.AddSimple(url, bytes.NewReader(payload), int64(len(payload)), true) + + // Remove Content-Length from HEAD response. + ft.ResponseHook = func(resp *http.Response) { + if resp.Request.Method == http.MethodHead { + resp.ContentLength = -1 + resp.Header.Del("Content-Length") + } + } + + client := &http.Client{Transport: New(ft)} + + resp, err := client.Get(url) + if err != nil { + t.Fatalf("GET: %v", err) + } + defer resp.Body.Close() + + got, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read: %v", err) + } + + testutil.AssertDataEquals(t, got, payload) + + // Should fall back to single request. + reqs := ft.GetRequests() + var rangeCount int + for _, req := range reqs { + if req.Header.Get("Range") != "" { + rangeCount++ + } + } + + if rangeCount != 0 { + t.Errorf("expected no range requests without Content-Length, got %d", + rangeCount) + } +} + +// TestNonGetRequest_PassesThrough verifies non-GET requests are passed +// through unmodified. +func TestNonGetRequest_PassesThrough(t *testing.T) { + url := "https://example.com/resource" + postData := []byte("post data") + responseData := []byte("response") + + ft := testutil.NewFakeTransport() + ft.AddSimple(url, bytes.NewReader(responseData), int64(len(responseData)), false) + + client := &http.Client{Transport: New(ft)} + + // Test POST request. + resp, err := client.Post(url, "application/json", + bytes.NewReader(postData)) + if err != nil { + t.Fatalf("POST failed: %v", err) + } + defer resp.Body.Close() + + got, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read failed: %v", err) + } + + testutil.AssertDataEquals(t, got, responseData) + + // Should not have any HEAD requests. + reqs := ft.GetRequests() + for _, req := range reqs { + if req.Method == http.MethodHead { + t.Error("unexpected HEAD request for non-GET method") + } + if req.Header.Get("Range") != "" { + t.Error("unexpected Range header for non-GET method") + } + } +} + +// TestWrongRangeResponse_HandlesError tests handling of incorrect range +// responses. +func TestWrongRangeResponse_HandlesError(t *testing.T) { + url := "https://example.com/wrong-range" + payload := testutil.GenerateTestData(100000) + + ft := testutil.NewFakeTransport() + ft.Add(url, &testutil.FakeResource{ + Data: bytes.NewReader(payload), + Length: int64(len(payload)), + SupportsRange: true, + }) + + // Return wrong range in response. + ft.ResponseHook = func(resp *http.Response) { + if resp.Request.Header.Get("Range") == "bytes=1000-1999" { + // Return different range than requested. + resp.Header.Set("Content-Range", "bytes 2000-2999/100000") + } + } + + client := &http.Client{Transport: New(ft)} + + // Make a specific range request. 
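+ // The transport passes explicit client Range headers through untouched,
+ // so this should translate into exactly one ranged GET upstream.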
+ req, err := http.NewRequest("GET", url, nil) + if err != nil { + t.Fatalf("create request: %v", err) + } + req.Header.Set("Range", "bytes=1000-1999") + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("GET failed: %v", err) + } + defer resp.Body.Close() + + // Should still work (parallel transport doesn't validate Content-Range + // for user requests). + got, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read failed: %v", err) + } + + // Should get the correct range data. + want := payload[1000:2000] + testutil.AssertDataEquals(t, got, want) +} + +// TestChunkBoundaries verifies correct chunk boundary calculation. +func TestChunkBoundaries(t *testing.T) { + url := "https://example.com/boundaries" + // Use specific size to test boundary conditions. + payload := testutil.GenerateTestData(10000) // Exactly 10KB. + + ft := testutil.NewFakeTransport() + ft.Add(url, &testutil.FakeResource{ + Data: bytes.NewReader(payload), + Length: int64(len(payload)), + SupportsRange: true, + }) + + client := &http.Client{ + Transport: New(ft, + WithMaxConcurrentPerRequest(4), + WithMinChunkSize(2500)), // Should result in 4 chunks of 2500 bytes. + } + + resp, err := client.Get(url) + if err != nil { + t.Fatalf("GET: %v", err) + } + defer resp.Body.Close() + + got, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read: %v", err) + } + + testutil.AssertDataEquals(t, got, payload) + + // Check the range requests. + reqs := ft.GetRequests() + + var actualRanges []string + for _, req := range reqs { + if r := req.Header.Get("Range"); r != "" && r != "bytes=0-0" { + actualRanges = append(actualRanges, r) + } + } + + // We might not get exactly these ranges due to scheduling, but verify we + // got multiple. + if len(actualRanges) < 2 { + t.Errorf("expected multiple range requests, got %d", len(actualRanges)) + } + + t.Logf("Actual ranges: %v", actualRanges) +} + +// TestETagChanged_FallsBackToSingle tests handling when ETag changes +// mid-download. +func TestETagChanged_FallsBackToSingle(t *testing.T) { + url := "https://example.com/changing" + payload := testutil.GenerateTestData(100000) + originalETag := `"original"` + changedETag := `"changed"` + + ft := testutil.NewFakeTransport() + ft.Add(url, &testutil.FakeResource{ + Data: bytes.NewReader(payload), + Length: int64(len(payload)), + SupportsRange: true, + ETag: originalETag, + }) + + requestCount := 0 + var mu sync.Mutex + ft.ResponseHook = func(resp *http.Response) { + mu.Lock() + requestCount++ + rc := requestCount + mu.Unlock() + // Change ETag after first request. + if rc > 1 && resp.Request.Header.Get("Range") != "" { + // Simulate resource change - return full content with new ETag. + resp.StatusCode = http.StatusOK + resp.Status = "200 OK" + resp.Header.Set("ETag", changedETag) + resp.Header.Del("Content-Range") + resp.Body = io.NopCloser(bytes.NewReader(payload)) + } + } + + client := &http.Client{Transport: New(ft)} + + resp, err := client.Get(url) + if err != nil { + t.Fatalf("GET: %v", err) + } + defer resp.Body.Close() + + got, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read: %v", err) + } + + // Should still get the full payload. + testutil.AssertDataEquals(t, got, payload) +} + +// TestNoValidator_StillWorks tests parallel download without ETag or +// Last-Modified. 
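+// With no validator to anchor If-Range, the transport should still chunk
+// the download and simply omit the header, as asserted below.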
+func TestNoValidator_StillWorks(t *testing.T) { + url := "https://example.com/no-validator" + payload := testutil.GenerateTestData(100000) + + ft := testutil.NewFakeTransport() + ft.Add(url, &testutil.FakeResource{ + Data: bytes.NewReader(payload), + Length: int64(len(payload)), + SupportsRange: true, + // No ETag or LastModified. + }) + + client := &http.Client{Transport: New(ft)} + + resp, err := client.Get(url) + if err != nil { + t.Fatalf("GET: %v", err) + } + defer resp.Body.Close() + + got, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read: %v", err) + } + + testutil.AssertDataEquals(t, got, payload) + + // Check that no If-Range headers were sent. + headers := ft.GetRequestHeaders(url) + for _, h := range headers { + if ifRange := h.Get("If-Range"); ifRange != "" { + t.Errorf("unexpected If-Range header: %q", ifRange) + } + } +} + +// TestConditionalHeadersScrubbed verifies conditional headers are removed. +func TestConditionalHeadersScrubbed(t *testing.T) { + url := "https://example.com/conditional" + payload := testutil.GenerateTestData(100000) + + ft := testutil.NewFakeTransport() + ft.Add(url, &testutil.FakeResource{ + Data: bytes.NewReader(payload), + Length: int64(len(payload)), + SupportsRange: true, + ETag: `"test"`, + }) + + // Track headers and validate scrubbing for both HEAD and range GETs. + ft.RequestHook = func(req *http.Request) { + // For range requests made by parallel transport, + // conditional headers should be removed. + if req.Header.Get("Range") != "" { + if req.Header.Get("If-Match") != "" { + t.Errorf("%s request: If-Match header should be removed", + req.Method) + } + if req.Header.Get("If-None-Match") != "" { + t.Errorf("%s request: If-None-Match header should be removed", + req.Method) + } + if req.Header.Get("If-Modified-Since") != "" { + t.Errorf("%s request: If-Modified-Since header should be removed", + req.Method) + } + if req.Header.Get("If-Unmodified-Since") != "" { + t.Errorf("%s request: If-Unmodified-Since header should be removed", + req.Method) + } + } + // HEAD made by parallel transport should scrub conditional headers and + // force identity encoding. + if req.Method == http.MethodHead { + if req.Header.Get("If-Match") != "" || + req.Header.Get("If-None-Match") != "" || + req.Header.Get("If-Modified-Since") != "" || + req.Header.Get("If-Unmodified-Since") != "" { + t.Error("HEAD request should have conditional headers scrubbed") + } + if ae := req.Header.Get("Accept-Encoding"); ae != "identity" { + t.Errorf("HEAD should set Accept-Encoding=identity, got %q", ae) + } + } + // If-Range should only be present on range requests with proper value. + if ifRange := req.Header.Get("If-Range"); ifRange != "" { + if req.Header.Get("Range") == "" { + t.Error("If-Range without Range header") + } + } + } + + client := &http.Client{Transport: New(ft)} + + // Create request with conditional headers. 
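+ // All four precondition headers are set here; the RequestHook above
+ // fails the test if any of them leak into the transport's sub-requests.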
+ req, err := http.NewRequest("GET", url, nil) + if err != nil { + t.Fatalf("create request: %v", err) + } + req.Header.Set("If-Match", `"wrong"`) + req.Header.Set("If-None-Match", `"also-wrong"`) + req.Header.Set("If-Modified-Since", "Wed, 21 Oct 2015 07:28:00 GMT") + req.Header.Set("If-Unmodified-Since", "Wed, 21 Oct 2015 07:28:00 GMT") + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("GET: %v", err) + } + defer resp.Body.Close() + + got, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read: %v", err) + } + + testutil.AssertDataEquals(t, got, payload) +} + +// TestRangeHeader_PassesThrough verifies that requests with an explicit +// Range header are passed through without parallelization, and no HEAD +// request is issued by the transport. +func TestRangeHeader_PassesThrough(t *testing.T) { + url := "https://example.com/ranged" + payload := testutil.GenerateTestData(8192) + + ft := testutil.NewFakeTransport() + ft.AddSimple(url, bytes.NewReader(payload), int64(len(payload)), true) + + client := &http.Client{Transport: New(ft, WithMaxConcurrentPerRequest(4))} + + req, err := http.NewRequest("GET", url, nil) + if err != nil { + t.Fatalf("create request: %v", err) + } + req.Header.Set("Range", "bytes=1000-1999") + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("GET failed: %v", err) + } + defer resp.Body.Close() + + got, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read failed: %v", err) + } + want := payload[1000:2000] + testutil.AssertDataEquals(t, got, want) + + // Ensure no HEAD was made and that only the user’s single GET with Range + // was sent (no extra parallel range requests). + reqs := ft.GetRequests() + var headCount, rangeGets int + for _, r := range reqs { + if r.Method == http.MethodHead { + headCount++ + } + if r.Method == http.MethodGet && r.Header.Get("Range") != "" { + rangeGets++ + } + } + if headCount != 0 { + t.Errorf("expected 0 HEAD requests, got %d", headCount) + } + if rangeGets != 1 { + t.Errorf("expected exactly 1 ranged GET, got %d", rangeGets) + } +} diff --git a/transport/resumable/transport.go b/transport/resumable/transport.go index 344a91c..2a45f3b 100644 --- a/transport/resumable/transport.go +++ b/transport/resumable/transport.go @@ -32,10 +32,11 @@ import ( "math" "math/rand" "net/http" - "strconv" "strings" "sync" "time" + + "github.com/docker/model-distribution/transport/internal/common" ) // Option configures a ResumableTransport. @@ -129,26 +130,20 @@ func isResumable(req *http.Request, resp *http.Response) bool { if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent { return false } - if !supportsRange(resp.Header) { + if !common.SupportsRange(resp.Header) { return false } // Disallow when the response was auto-decompressed or has a Content-Encoding. if resp.Uncompressed || resp.Header.Get("Content-Encoding") != "" { return false } - return true -} - -// supportsRange determines whether or not an HTTP response indicates support -// for range requests. -func supportsRange(h http.Header) bool { - ar := strings.ToLower(h.Get("Accept-Ranges")) - for _, part := range strings.Split(ar, ",") { - if strings.TrimSpace(part) == "bytes" { - return true + // If the original request specified a Range, only support single-range. + if r := req.Header.Get("Range"); strings.TrimSpace(r) != "" { + if _, _, ok := common.ParseSingleRange(r); !ok { + return false } } - return false + return true } // resumableBody wraps a Response.Body to add transparent resume support. 
@@ -203,7 +198,7 @@ func newResumableBody(req *http.Request, resp *http.Response, tr *ResumableTrans } // Extract starting offsets from request Range if present (single-range only). - if start, end, ok := parseSingleRange(rb.originalRangeSpec); ok { + if start, end, ok := common.ParseSingleRange(rb.originalRangeSpec); ok { rb.initialStart = start if end >= 0 { rb.initialEnd = &end @@ -212,7 +207,7 @@ func newResumableBody(req *http.Request, resp *http.Response, tr *ResumableTrans // Refine offsets from Content-Range header if response was 206. if resp.StatusCode == http.StatusPartialContent { - if s, e, total, ok := parseContentRange(resp.Header.Get("Content-Range")); ok { + if s, e, total, ok := common.ParseContentRange(resp.Header.Get("Content-Range")); ok { rb.initialStart = s if e >= 0 { rb.initialEnd = &e @@ -221,13 +216,19 @@ func newResumableBody(req *http.Request, resp *http.Response, tr *ResumableTrans rb.totalSize = &total } } - } else if resp.ContentLength >= 0 { // 200 OK - total := int64(resp.ContentLength) - rb.totalSize = &total + } else if resp.StatusCode == http.StatusOK { + // For 200 OK, the server is sending a full stream starting at 0 + // regardless of any Range header on the request. + rb.initialStart = 0 + rb.initialEnd = nil + if resp.ContentLength >= 0 { + total := int64(resp.ContentLength) + rb.totalSize = &total + } } // Capture validators for If-Range to ensure consistency across resumes. - if et := resp.Header.Get("ETag"); et != "" && !isWeakETag(et) { + if et := resp.Header.Get("ETag"); et != "" && !common.IsWeakETag(et) { rb.etag = et } else if lm := resp.Header.Get("Last-Modified"); lm != "" { rb.lastModified = lm @@ -236,64 +237,97 @@ func newResumableBody(req *http.Request, resp *http.Response, tr *ResumableTrans } // Read delivers bytes to the caller. If an error occurs mid-stream, it will -// transparently try to resume by issuing a new Range request. +// transparently try to resume by issuing a new Range request. When the total +// length is unknown (e.g., 200 OK without Content-Length), completeness cannot +// be verified precisely; in such cases EOF is treated as the natural end. func (rb *resumableBody) Read(p []byte) (int, error) { - rb.mu.Lock() - defer rb.mu.Unlock() - - if rb.done { - return 0, io.EOF - } - if rb.rc == nil { - // No active body — must resume from the last delivered offset. - if err := rb.resume(rb.bytesRead); err != nil { - return 0, err + for { + // Snapshot state without holding the lock across I/O. + rb.mu.Lock() + if rb.done { + rb.mu.Unlock() + return 0, io.EOF + } + rc := rb.rc + planned, plannedOK := rb.plannedLength() + already := rb.bytesRead + rb.mu.Unlock() + + if rc == nil { + if err := rb.resume(already); err != nil { + return 0, err + } + continue } - } - n, err := rb.rc.Read(p) - rb.bytesRead += int64(n) - - switch { - case err == nil: - return n, nil - case errors.Is(err, io.EOF): - rb.done = true - return n, io.EOF - default: - // Underlying read failed mid-stream. Try to resume. - _ = rb.rc.Close() - rb.rc = nil + n, err := rc.Read(p) - if n > 0 { - // Surface bytes already read; the caller will call Read again. 
- return n, nil - } - if rb.retriesUsed >= rb.tr.maxRetries { - return 0, err - } - if rerr := rb.resume(rb.bytesRead); rerr != nil { - return 0, rerr - } + rb.mu.Lock() + rb.bytesRead += int64(n) - n2, err2 := rb.rc.Read(p) - rb.bytesRead += int64(n2) - if err2 == nil { - return n2, nil - } - if errors.Is(err2, io.EOF) { + switch { + case err == nil: + rb.mu.Unlock() + return n, nil + case errors.Is(err, io.EOF): + // If planned length is known and we are short, resume. + if plannedOK && already+int64(n) < planned { + _ = rb.rc.Close() + rb.rc = nil + if rb.retriesUsed >= rb.tr.maxRetries { + rb.mu.Unlock() + return n, io.ErrUnexpectedEOF + } + // Return bytes now; resume on next call. + if n > 0 { + rb.mu.Unlock() + return n, nil + } + // Resume outside lock. + nextOffset := rb.bytesRead + rb.mu.Unlock() + if rerr := rb.resume(nextOffset); rerr != nil { + return 0, rerr + } + continue + } + // Completed. rb.done = true + rb.mu.Unlock() + return n, io.EOF + default: + // Underlying read failed mid-stream. Try to resume. + _ = rb.rc.Close() + rb.rc = nil + + if n > 0 { + rb.mu.Unlock() + // Surface bytes already read; the caller will call Read again. + return n, nil + } + if rb.retriesUsed >= rb.tr.maxRetries { + rb.mu.Unlock() + return 0, err + } + off := rb.bytesRead + rb.mu.Unlock() + if rerr := rb.resume(off); rerr != nil { + return 0, rerr + } + continue } - return n2, err2 } } // Close closes the current response body if present. func (rb *resumableBody) Close() error { rb.mu.Lock() - defer rb.mu.Unlock() - if rb.rc != nil { - return rb.rc.Close() + rc := rb.rc + rb.rc = nil + rb.done = true + rb.mu.Unlock() + if rc != nil { + return rc.Close() } return nil } @@ -320,6 +354,12 @@ func (rb *resumableBody) resume(absoluteOffset int64) error { return err } + // For safety, do not attempt an unvalidated resume when neither a + // strong ETag nor Last-Modified validator is available. + if rb.etag == "" && rb.lastModified == "" { + return fmt.Errorf("resumable: cannot resume without validator") + } + start := rb.initialStart + absoluteOffset rangeVal := buildRangeHeader(start, rb.initialEnd) req := rb.cloneBaseRequest(rangeVal) @@ -339,13 +379,22 @@ func (rb *resumableBody) resume(absoluteOffset int64) error { switch resp.StatusCode { case http.StatusPartialContent: // Validate server honored our starting offset precisely. - s, _, _, ok := parseContentRange(resp.Header.Get("Content-Range")) + s, e, _, ok := common.ParseContentRange(resp.Header.Get("Content-Range")) if !ok || s != start { _ = resp.Body.Close() continue // try again; mismatched range } - rb.swapResponse(resp) + // If we requested a closed range and the end does not match, do + // not accept this response. + if rb.initialEnd != nil && e >= 0 && e != *rb.initialEnd { + _ = resp.Body.Close() + continue + } + // Install the new response under lock. 
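+ // Holding rb.mu keeps the swap of rb.rc and the validator/size updates
+ // atomic with respect to concurrent Read and Close calls.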
+ rb.mu.Lock()
+ rb.installResponseLocked(resp)
 rb.retriesUsed++
+ rb.mu.Unlock()
 return nil

 case http.StatusOK:
@@ -354,6 +403,12 @@ func (rb *resumableBody) resume(absoluteOffset int64) error {
 _ = resp.Body.Close()
 return fmt.Errorf("resumable: server returned 200 to a range request; resource may have changed")

+ case http.StatusMultipleChoices, http.StatusMovedPermanently, http.StatusFound,
+ http.StatusSeeOther, http.StatusNotModified, http.StatusUseProxy,
+ http.StatusTemporaryRedirect, http.StatusPermanentRedirect:
+ _ = resp.Body.Close()
+ return fmt.Errorf("resumable: resume received redirect status %d", resp.StatusCode)
+
 case http.StatusRequestedRangeNotSatisfiable:
 // If we've already read to or past the expected end, we are actually done.
 if rb.rangeIsComplete(absoluteOffset) {
@@ -370,9 +425,9 @@
 return fmt.Errorf("resumable: exceeded retry budget after %d attempts", rb.tr.maxRetries)
}

-// swapResponse replaces the current response body with a new one
-// from a resumed request, and updates any validators and size info.
-func (rb *resumableBody) swapResponse(resp *http.Response) {
+// installResponseLocked installs resp as the current response and updates
+// validators and size info. Caller must hold rb.mu.
+func (rb *resumableBody) installResponseLocked(resp *http.Response) {
 if rb.rc != nil && rb.rc != resp.Body {
 _ = rb.rc.Close()
 }
@@ -380,7 +435,7 @@
 rb.rc = resp.Body

 // Persist validators from the server if they are strong.
- if et := resp.Header.Get("ETag"); et != "" && !isWeakETag(et) {
+ if et := resp.Header.Get("ETag"); et != "" && !common.IsWeakETag(et) {
 rb.etag = et
 }
 if lm := resp.Header.Get("Last-Modified"); lm != "" {
@@ -388,8 +443,8 @@
 }

 // Merge any updated size info from the Content-Range.
- if s, e, total, ok := parseContentRange(resp.Header.Get("Content-Range")); ok {
- _ = s // start is validated by caller
+ if s, e, total, ok := common.ParseContentRange(resp.Header.Get("Content-Range")); ok {
+ _ = s // start validated by caller
 if e >= 0 {
 rb.initialEnd = &e
 }
@@ -406,19 +461,17 @@ func (rb *resumableBody) cloneBaseRequest(rangeVal string) *http.Request {
 req := rb.origReq.Clone(rb.ctx)
 req.Body = nil
 req.ContentLength = 0
- req.Header = cloneHeader(rb.origReq.Header)
+ req.Header = rb.origReq.Header.Clone()

 // Ensure we control the Range validator set.
 req.Header.Set("Range", rangeVal)

 // Remove conditional headers that could conflict with If-Range semantics.
- scrubConditionalHeaders(req.Header)
+ common.ScrubConditionalHeaders(req.Header)

 if rb.etag != "" {
 req.Header.Set("If-Range", rb.etag)
 } else if rb.lastModified != "" {
 req.Header.Set("If-Range", rb.lastModified)
- } else {
- // If no validator, we still attempt Range but risk a 200 if server can't verify.
 }

 // Prevent transparent decompression on resumed requests.
@@ -426,28 +479,6 @@
 return req
}

-// cloneHeader makes a deep copy of an http.Header map.
-func cloneHeader(h http.Header) http.Header {
- out := make(http.Header, len(h))
- for k, vv := range h {
- cp := make([]string, len(vv))
- copy(cp, vv)
- out[k] = cp
- }
- return out
-}
-
-// scrubConditionalHeaders removes conditional headers we do not want to forward
-// on resumed Range requests, because they can alter semantics or conflict with
-// If-Range logic.
-func scrubConditionalHeaders(h http.Header) { - h.Del("If-None-Match") - h.Del("If-Modified-Since") - h.Del("If-Match") - h.Del("If-Unmodified-Since") - // We overwrite Range/If-Range explicitly elsewhere. -} - // buildRangeHeader constructs a "Range" header value for a given start and // optional inclusive end. func buildRangeHeader(start int64, end *int64) string { @@ -467,8 +498,10 @@ func waitBackoff(ctx context.Context, bf BackoffFunc, attempt int) error { if d <= 0 { return nil } + t := time.NewTimer(d) + defer t.Stop() select { - case <-time.After(d): + case <-t.C: return nil case <-ctx.Done(): return ctx.Err() @@ -492,86 +525,3 @@ func (rb *resumableBody) rangeIsComplete(absoluteOffset int64) bool { } return false } - -// isWeakETag reports whether the ETag is a weak validator (W/"...") which must -// not be used with If-Range per RFC 7232 §2.1. -func isWeakETag(etag string) bool { - etag = strings.TrimSpace(etag) - return strings.HasPrefix(etag, "W/") || strings.HasPrefix(etag, "w/") -} - -// ─────────────────────────── Helpers: header parsing ────────────────────────── - -// parseSingleRange parses a single "Range: bytes=start-end" header. -// It returns (start, end, ok). When end is omitted, end == -1. -// -// Notes: -// - Only absolute-start forms are supported (no suffix ranges "-N"). -// - Multi-range specifications (comma separated) return ok == false. -func parseSingleRange(h string) (int64, int64, bool) { - if h == "" { - return 0, -1, false - } - h = strings.TrimSpace(h) - if !strings.HasPrefix(strings.ToLower(h), "bytes=") { - return 0, -1, false - } - spec := strings.TrimSpace(h[len("bytes="):]) - if strings.Contains(spec, ",") { - return 0, -1, false - } - parts := strings.SplitN(spec, "-", 2) - if len(parts) != 2 { - return 0, -1, false - } - if parts[0] == "" { - // Suffix form is not supported here. - return 0, -1, false - } - start, err := strconv.ParseInt(strings.TrimSpace(parts[0]), 10, 64) - if err != nil || start < 0 { - return 0, -1, false - } - end := int64(-1) - if strings.TrimSpace(parts[1]) != "" { - e, err := strconv.ParseInt(strings.TrimSpace(parts[1]), 10, 64) - if err != nil || e < start { - return 0, -1, false - } - end = e - } - return start, end, true -} - -// parseContentRange parses "Content-Range: bytes start-end/total". -// It returns (start, end, total, ok). When total is unknown, total == -1. 
-func parseContentRange(h string) (int64, int64, int64, bool) { - if h == "" { - return 0, -1, -1, false - } - h = strings.ToLower(strings.TrimSpace(h)) - if !strings.HasPrefix(h, "bytes ") { - return 0, -1, -1, false - } - body := strings.TrimSpace(h[len("bytes "):]) - seTotal := strings.SplitN(body, "/", 2) - if len(seTotal) != 2 { - return 0, -1, -1, false - } - se := strings.SplitN(strings.TrimSpace(seTotal[0]), "-", 2) - if len(se) != 2 { - return 0, -1, -1, false - } - start, err1 := strconv.ParseInt(strings.TrimSpace(se[0]), 10, 64) - end, err2 := strconv.ParseInt(strings.TrimSpace(se[1]), 10, 64) - totalStr := strings.TrimSpace(seTotal[1]) - var total int64 = -1 - var err3 error - if totalStr != "*" { - total, err3 = strconv.ParseInt(totalStr, 10, 64) - } - if err1 != nil || err2 != nil || (err3 != nil && totalStr != "*") { - return 0, -1, -1, false - } - return start, end, total, true -} diff --git a/transport/resumable/transport_test.go b/transport/resumable/transport_test.go index 97bb994..7b3262b 100644 --- a/transport/resumable/transport_test.go +++ b/transport/resumable/transport_test.go @@ -2,1194 +2,1385 @@ package resumable import ( "bytes" + "fmt" "io" "net/http" - "strconv" "strings" "sync" "testing" "time" + + testutil "github.com/docker/model-distribution/transport/internal/testing" ) -// ───────────────────────── Test Harness Types & Utilities ───────────────────────── - -// flakePlan specifies the behavior of the fake transport for a single URL. -// It allows tests to deterministically exercise success and error paths. -type flakePlan struct { - // CutAfter defines, for each served segment, the number of bytes to deliver - // before injecting a read error. Segment index 0 corresponds to the initial - // request (without a Range header). Segment index 1 corresponds to the first - // resume attempt, and so on. A value of -1 means "no failure for that segment". - // If the slice does not include an entry for a segment, that segment will - // default to no failure. - CutAfter []int - - // ForceNon206OnResume indicates that the server should ignore Range headers - // on resume requests and respond with HTTP 200 OK (full body). This is used - // to verify that the client rejects non-206 responses during resume. - ForceNon206OnResume bool - - // WrongStartOnResume indicates that the server should respond with HTTP 206 - // Partial Content but a Content-Range whose start is strictly different from - // the requested start (we skew it forward by +1). This should be rejected by - // the client as unsafe. - WrongStartOnResume bool - - // NoRangeSupport indicates that the server does not support range requests - // (no Accept-Ranges header and never returns 206). This ensures the wrapper - // pass-through behavior is exercised. - NoRangeSupport bool - - // RequireIfRange indicates that the server requires a valid If-Range header - // on resumed requests. The expected validator is the current strong ETag, if - // present; otherwise the current Last-Modified value. If the header is missing - // or does not match, the server returns HTTP 200 OK with the full body. - RequireIfRange bool - - // ChangeETagOnResume indicates that the server mutates its ETag for resumed - // requests. This is used to simulate resource changes between segments. - ChangeETagOnResume bool - - // ChangeLastModifiedOnResume indicates that the server mutates its - // Last-Modified timestamp for resumed requests. This also simulates a change. 
- ChangeLastModifiedOnResume bool - - // OmitETag indicates that the server should omit the ETag header from - // responses. This forces clients to fall back to Last-Modified for If-Range. - OmitETag bool - - // OmitLastModified indicates that the server should omit the Last-Modified - // header from responses. - OmitLastModified bool - - // InitialContentEncoding, when non-empty, is set as the Content-Encoding - // of the initial (non-Range) response to simulate compressed delivery. - InitialContentEncoding string +// blockingBody simulates a response body that blocks on Read until closed. +type blockingBody struct { + ch chan struct{} } -// fakeTransport is a deterministic, concurrency-safe test double that implements -// http.RoundTripper. It serves byte slices from an in-memory map and can emulate -// flakiness and protocol misbehaviors based on a per-URL flakePlan. -type fakeTransport struct { - mu sync.Mutex // guards all fields below - - // resources maps absolute URL strings to the byte content that will be served. - resources map[string][]byte - - // plans maps absolute URL strings to their associated behavioral plan. - plans map[string]*flakePlan +func newBlockingBody() *blockingBody { return &blockingBody{ch: make(chan struct{})} } +func (b *blockingBody) Read(p []byte) (int, error) { <-b.ch; return 0, io.EOF } +func (b *blockingBody) Close() error { close(b.ch); return nil } - // etags maps absolute URL strings to the canonical (usually STRONG) ETag that - // represents the current version for the initial request (segment 0). - etags map[string]string - - // lastModified maps absolute URL strings to the Last-Modified timestamp value - // for the initial request (segment 0), formatted per RFC 7231. - lastModified map[string]string +// TestResumeSingleFailure_Succeeds tests resuming after a single failure. +func TestResumeSingleFailure_Succeeds(t *testing.T) { + url := "https://example.com/test-file" + payload := testutil.GenerateTestData(5000) - // seg tracks how many segments have been served per URL. The initial request - // uses segment 0, the first resume uses segment 1, and so forth. - seg map[string]int + ft := testutil.NewFakeTransport() + ft.Add(url, &testutil.FakeResource{ + Data: bytes.NewReader(payload), + Length: int64(len(payload)), + SupportsRange: true, + ETag: `"test-etag"`, + }) - // lastReqHeaders stores a copy of request headers for each segment per URL, - // allowing tests to assert on what the client sent (e.g., Accept-Encoding). - lastReqHeaders map[string][]http.Header -} + // Simulate failure after 2500 bytes on first request. + ft.SetFailAfter(url, 2500) -// newFakeTransport constructs and returns a new fakeTransport with all internal -// maps initialized. It is ready for use as an http.RoundTripper. -func newFakeTransport() *fakeTransport { - return &fakeTransport{ - resources: make(map[string][]byte), - plans: make(map[string]*flakePlan), - etags: make(map[string]string), - lastModified: make(map[string]string), - seg: make(map[string]int), - lastReqHeaders: make(map[string][]http.Header), + client := &http.Client{ + Transport: New(ft, WithMaxRetries(3)), } -} -// add registers a new URL, its byte payload, and its behavior plan with the -// fake transport. The ETag is STRONG by default; Last-Modified is fixed. -func (ft *fakeTransport) add(url string, data []byte, plan *flakePlan) { - ft.mu.Lock() - defer ft.mu.Unlock() - ft.resources[url] = data - ft.plans[url] = plan - // Strong ETag to match client logic (If-Range only with strong ETags by spec). 
- ft.etags[url] = `"` + strings.ReplaceAll(url, "/", "_") + `"` - ft.lastModified[url] = time.Unix(1_700_000_000, 0).UTC().Format(http.TimeFormat) -} + resp, err := client.Get(url) + if err != nil { + t.Fatalf("GET failed: %v", err) + } + defer resp.Body.Close() -// segmentHeaders returns copies of the headers for each segment requested for url. -func (ft *fakeTransport) segmentHeaders(url string) []http.Header { - ft.mu.Lock() - defer ft.mu.Unlock() - hs := ft.lastReqHeaders[url] - out := make([]http.Header, len(hs)) - for i := range hs { - out[i] = cloneHeader(hs[i]) + got, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read failed: %v", err) } - return out -} -// RoundTrip implements http.RoundTripper for fakeTransport. It interprets the -// incoming request, consults the configured plan for the target URL, and returns -// an HTTP response that adheres to the requested scenario (including failures). -func (ft *fakeTransport) RoundTrip(req *http.Request) (*http.Response, error) { - // Record the headers we received for this URL/segment for later assertions. - rurl := req.URL.String() - ft.mu.Lock() - ft.lastReqHeaders[rurl] = append(ft.lastReqHeaders[rurl], cloneHeader(req.Header)) - - data, ok := ft.resources[rurl] - plan := ft.plans[rurl] - etag := ft.etags[rurl] - lm := ft.lastModified[rurl] - if !ok { - ft.mu.Unlock() - return &http.Response{StatusCode: http.StatusNotFound, Body: io.NopCloser(bytes.NewReader(nil)), Request: req}, nil - } - seg := ft.seg[rurl] - ft.seg[rurl] = seg + 1 - ft.mu.Unlock() - - total := int64(len(data)) - rangeHdr := req.Header.Get("Range") - supportsRange := plan == nil || !plan.NoRangeSupport - - // Compute validators for this segment (may be omitted or mutated by the plan). - curETag := etag - curLM := lm - if plan != nil { - if plan.OmitETag { - curETag = "" - } - if plan.OmitLastModified { - curLM = "" - } - if rangeHdr != "" { // resume only - if plan.ChangeETagOnResume && curETag != "" { - curETag = curETag + "-changed" - } - if plan.ChangeLastModifiedOnResume && curLM != "" { - parsed, _ := time.Parse(http.TimeFormat, lm) // Safe default - curLM = parsed.Add(1 * time.Second).UTC().Format(http.TimeFormat) - } + testutil.AssertDataEquals(t, got, payload) + + // Verify resume happened. + reqs := ft.GetRequests() + var rangeRequests int + for _, req := range reqs { + if req.Header.Get("Range") != "" { + rangeRequests++ + t.Logf("Range request: %s", req.Header.Get("Range")) } } - // Determine cut-off point for this segment (if any). - cutAfter := -1 - if plan != nil && seg < len(plan.CutAfter) { - cutAfter = plan.CutAfter[seg] + if rangeRequests < 1 { + t.Error("expected at least one range request for resume") } +} - // Helper to build a body that fails after N bytes, if requested. 
- makeBody := func(b []byte, cut int) io.ReadCloser { - if cut < 0 || cut >= len(b) { - return io.NopCloser(bytes.NewReader(b)) - } - return newFlakyReader(b, cut) - } - - // Initial request (no Range header): - if rangeHdr == "" { - if !supportsRange { - return &http.Response{ - Status: "200 OK", - StatusCode: http.StatusOK, - Header: cloneHeader(http.Header{}), - ContentLength: total, - Body: makeBody(data, cutAfter), - Request: req, - }, nil - } - h := http.Header{} - h.Set("Accept-Ranges", "bytes") - if curETag != "" { - h.Set("ETag", curETag) - } - if curLM != "" { - h.Set("Last-Modified", curLM) - } - if plan != nil && plan.InitialContentEncoding != "" { - h.Set("Content-Encoding", plan.InitialContentEncoding) +// TestResumeMultipleFailuresWithinBudget_Succeeds tests multiple resume +// attempts. +func TestResumeMultipleFailuresWithinBudget_Succeeds(t *testing.T) { + url := "https://example.com/multi-fail" + payload := testutil.GenerateTestData(10000) + + ft := testutil.NewFakeTransport() + + ft.Add(url, &testutil.FakeResource{ + Data: bytes.NewReader(payload), + Length: int64(len(payload)), + SupportsRange: true, + ETag: `"multi-fail-etag"`, + }) + + // Hook to inject failures - use SetFailAfter multiple times. + failurePoints := []int{2000, 5000, 7500} + failureIndex := 0 + requestCount := 0 + var mu sync.Mutex + ft.ResponseHook = func(resp *http.Response) { + if resp.Request.Method == http.MethodGet && + failureIndex < len(failurePoints) { + // For non-range requests, inject failure. + if resp.Request.Header.Get("Range") == "" { + mu.Lock() + idx := failureIndex + failureIndex++ + mu.Unlock() + resp.Body = testutil.NewFlakyReader( + bytes.NewReader(payload), + int64(len(payload)), + failurePoints[idx], + ) + } else { + // For range requests, check which failure point we're at. + mu.Lock() + requestCount++ + rc := requestCount + fi := failureIndex + mu.Unlock() + if rc <= len(failurePoints) && + fi < len(failurePoints) { + // Parse range to determine data slice. + rangeHeader := resp.Request.Header.Get("Range") + if rangeHeader != "" { + // Simple parsing for bytes=N- format. + var start int + fmt.Sscanf(rangeHeader, "bytes=%d-", &start) + rangeData := payload[start:] + + // Apply next failure point relative to this + // range. + nextFailure := failurePoints[fi] - start + if nextFailure > 0 && + nextFailure < len(rangeData) { + resp.Body = testutil.NewFlakyReader( + bytes.NewReader(rangeData), + int64(len(rangeData)), + nextFailure, + ) + mu.Lock() + failureIndex++ + mu.Unlock() + } + } + } + } } - return &http.Response{ - Status: "200 OK", - StatusCode: http.StatusOK, - Header: h, - ContentLength: total, - Body: makeBody(data, cutAfter), - Request: req, - }, nil } - // Resume (Range present): - if !supportsRange { - return &http.Response{Status: "200 OK", StatusCode: http.StatusOK, Header: http.Header{}, ContentLength: total, Body: makeBody(data, cutAfter), Request: req}, nil + client := &http.Client{ + Transport: New(ft, WithMaxRetries(5)), } - // Parse the Range header (bytes=start[-end]). 
- var start, end int64 = 0, total - 1 - if !strings.HasPrefix(strings.ToLower(rangeHdr), "bytes=") { - return &http.Response{StatusCode: http.StatusRequestedRangeNotSatisfiable, Body: io.NopCloser(bytes.NewReader(nil)), Request: req}, nil - } - spec := strings.TrimSpace(rangeHdr[len("bytes="):]) - parts := strings.SplitN(spec, "-", 2) - if len(parts) != 2 || parts[0] == "" { - return &http.Response{StatusCode: http.StatusRequestedRangeNotSatisfiable, Body: io.NopCloser(bytes.NewReader(nil)), Request: req}, nil + resp, err := client.Get(url) + if err != nil { + t.Fatalf("GET failed: %v", err) } - var err error - start, err = strconv.ParseInt(parts[0], 10, 64) + defer resp.Body.Close() + + got, err := io.ReadAll(resp.Body) if err != nil { - return &http.Response{StatusCode: http.StatusRequestedRangeNotSatisfiable, Body: io.NopCloser(bytes.NewReader(nil)), Request: req}, nil + t.Fatalf("read failed: %v", err) } - if parts[1] != "" { - end, err = strconv.ParseInt(parts[1], 10, 64) - if err != nil || end < start { - return &http.Response{StatusCode: http.StatusRequestedRangeNotSatisfiable, Body: io.NopCloser(bytes.NewReader(nil)), Request: req}, nil + + testutil.AssertDataEquals(t, got, payload) + + // Check that multiple resumes happened. + reqs := ft.GetRequests() + var rangeCount int + for _, req := range reqs { + if req.Header.Get("Range") != "" { + rangeCount++ } } - if start >= total { - h := http.Header{} - h.Set("Content-Range", "bytes */"+strconv.FormatInt(total, 10)) - return &http.Response{StatusCode: http.StatusRequestedRangeNotSatisfiable, Header: h, Body: io.NopCloser(bytes.NewReader(nil)), Request: req}, nil - } - if end >= total { - end = total - 1 - } - // If-Range enforcement (server side), if requested by the plan. - if plan != nil && plan.RequireIfRange { - // If the advertised ETag is weak, treat it as unusable for If-Range and - // require Last-Modified instead (aligns with RFC 7232/7233 and client logic). - expected := curETag - if expected == "" || isWeakETag(expected) { - expected = curLM - } - ir := req.Header.Get("If-Range") - if expected == "" || ir == "" || ir != expected { - h := http.Header{} - h.Set("Accept-Ranges", "bytes") - if curETag != "" { - h.Set("ETag", curETag) - } - if curLM != "" { - h.Set("Last-Modified", curLM) - } - return &http.Response{Status: "200 OK", StatusCode: http.StatusOK, Header: h, ContentLength: total, Body: makeBody(data, cutAfter), Request: req}, nil - } + if rangeCount < 2 { + t.Errorf("expected at least 2 range requests, got %d", rangeCount) } +} - if plan != nil && plan.ForceNon206OnResume { - h := http.Header{} - h.Set("Accept-Ranges", "bytes") - if curETag != "" { - h.Set("ETag", curETag) - } - if curLM != "" { - h.Set("Last-Modified", curLM) +// TestExceedRetryBudget_Fails tests failure when retry budget is exceeded. +func TestExceedRetryBudget_Fails(t *testing.T) { + url := "https://example.com/too-many-failures" + payload := testutil.GenerateTestData(4096) + + ft := testutil.NewFakeTransport() + ft.Add(url, &testutil.FakeResource{ + Data: bytes.NewReader(payload), + Length: int64(len(payload)), + SupportsRange: true, + ETag: `"fail-test"`, + }) + + // Always fail after 100 bytes. 
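+ // The hook below rewraps every GET body, including resumed range
+ // requests, so each attempt dies mid-stream until the budget is spent.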
+ ft.ResponseHook = func(resp *http.Response) {
+ if resp.Request.Method == http.MethodGet {
+ resp.Body = testutil.NewFlakyReader(bytes.NewReader(payload), int64(len(payload)), 100)
+ }
+ }
+
+ client := &http.Client{
+ Transport: New(ft, WithMaxRetries(2)), // Low retry limit.
+ }
+
+ resp, err := client.Get(url)
+ if err != nil {
+ t.Fatalf("GET failed: %v", err)
+ }
+ defer resp.Body.Close()
+
+ _, err = io.ReadAll(resp.Body)
+ if err == nil {
+ t.Error("expected error after exceeding retry budget")
+ }
+
+ // Check that retries were attempted.
+ reqs := ft.GetRequests()
+ var attempts int
+ for _, req := range reqs {
+ if req.Method == http.MethodGet {
+ attempts++
+ }
+ }
+
+ // Initial attempt + 2 retries = at most 3 GETs; at least 2 proves a
+ // retry was actually attempted.
+ if attempts < 2 {
+ t.Errorf("expected at least 2 GET attempts, got %d", attempts)
+ }
+}
+
+// TestReadCloseInterleaving ensures Close does not deadlock with a blocked Read
+// and unblocks promptly.
+func TestReadCloseInterleaving(t *testing.T) {
+ url := "https://example.com/blocking"
+ payload := testutil.GenerateTestData(1024)
+
+ ft := testutil.NewFakeTransport()
+ ft.Add(url, &testutil.FakeResource{
+ Data: bytes.NewReader(payload),
+ Length: int64(len(payload)),
+ SupportsRange: true,
+ ETag: `"etag"`,
+ })
+ // Replace body with a blocking body for the initial GET.
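+ // The body blocks inside Read until Close fires, modeling a stalled
+ // connection rather than a fast failure.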
+ bb := newBlockingBody() + ft.ResponseHook = func(resp *http.Response) { + if resp.Request.Method == http.MethodGet && resp.Request.Header.Get("Range") == "" { + resp.Body = bb + } } - remain := len(fr.data) - fr.pos - n := len(p) - if n > remain { - n = remain + + client := &http.Client{Transport: New(ft)} + resp, err := client.Get(url) + if err != nil { + t.Fatalf("GET: %v", err) } - if fr.pos < fr.cutAfter && fr.cutAfter < len(fr.data) { - max := fr.cutAfter - fr.pos - if max <= 0 { - return 0, io.ErrUnexpectedEOF - } - if n > max { - n = max - } - copy(p[:n], fr.data[fr.pos:fr.pos+n]) - fr.pos += n - if fr.pos >= fr.cutAfter { - return n, io.ErrUnexpectedEOF - } - return n, nil + + done := make(chan struct{}) + go func() { + defer close(done) + _, _ = io.ReadAll(resp.Body) + }() + + // Close should unblock the read goroutine promptly. + if err := resp.Body.Close(); err != nil { + t.Fatalf("close: %v", err) } - copy(p[:n], fr.data[fr.pos:fr.pos+n]) - fr.pos += n - if fr.pos >= len(fr.data) { - return n, io.EOF + select { + case <-done: + case <-time.After(1 * time.Second): + t.Fatal("read did not unblock after Close") } - return n, nil } -func (fr *flakyReader) Close() error { fr.closed = true; return nil } +// TestMultiRange_PassThrough ensures multi-range requests are not wrapped. +func TestMultiRange_PassThrough(t *testing.T) { + url := "https://example.com/multirange" + payload := testutil.GenerateTestData(4096) -// newClient is a small helper to build an http.Client with our resumable transport. -func newClient(rt http.RoundTripper, retries int) *http.Client { - return &http.Client{Transport: New(rt, WithMaxRetries(retries), WithBackoff(func(int) time.Duration { return 0 }))} -} + ft := testutil.NewFakeTransport() + ft.Add(url, &testutil.FakeResource{ + Data: bytes.NewReader(payload), + Length: int64(len(payload)), + SupportsRange: true, + }) -// ─────────────────────────────────── Tests ─────────────────────────────────── + client := &http.Client{Transport: New(ft)} + req, err := http.NewRequest("GET", url, nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Range", "bytes=0-10,20-30") -// TestResumeSingleFailure_Succeeds verifies that a single mid-stream failure on -// the initial response is successfully recovered by a single resume attempt, -// and that the resulting assembled payload matches exactly. -func TestResumeSingleFailure_Succeeds(t *testing.T) { - // Arrange: create payload and fake transport that cuts once on the initial segment. - url := "https://example.com/blob" - payload := bytes.Repeat([]byte("abcde"), 1_000) // 5,000 bytes - ft := newFakeTransport() - ft.add(url, payload, &flakePlan{CutAfter: []int{1000, -1}}) - - // Act: issue GET and read to completion through resumable transport. - client := newClient(ft, 3) - resp, err := client.Get(url) + resp, err := client.Do(req) if err != nil { t.Fatalf("GET: %v", err) } - t.Cleanup(func() { resp.Body.Close() }) + defer resp.Body.Close() - // Assert Content-Length for initial 200 - if resp.ContentLength != int64(len(payload)) { - t.Errorf("ContentLength = %d, want %d", resp.ContentLength, len(payload)) + // FakeTransport does not implement multi-range; it returns 400. + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected 400 from fake transport for multi-range, got %d", resp.StatusCode) } - got, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("read: %v", err) + // Ensure no If-Range was injected on request headers. 
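+ // A multi-range spec fails single-range parsing, so the transport must
+ // decline to wrap the request and leave its headers alone.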
+ hdrs := ft.GetRequestHeaders(url) + for _, h := range hdrs { + if h.Get("If-Range") != "" { + t.Error("unexpected If-Range header on multi-range request") + } } +} - // Assert: reconstructed body matches original payload. - if !bytes.Equal(got, payload) { - t.Errorf("payload mismatch: got %d bytes, want %d", len(got), len(payload)) +// TestInitialRange_200OK_Ignored ensures if server responds 200 to a ranged +// request, the stream is treated as starting at 0 and reads succeed. +func TestInitialRange_200OK_Ignored(t *testing.T) { + url := "https://example.com/range-ignored" + payload := testutil.GenerateTestData(2048) + + ft := testutil.NewFakeTransport() + ft.Add(url, &testutil.FakeResource{ + Data: bytes.NewReader(payload), + Length: int64(len(payload)), + SupportsRange: true, + ETag: `"e"`, + }) + + // Force 200 full response even when Range is present. + ft.ResponseHook = func(resp *http.Response) { + if resp.Request.Header.Get("Range") != "" && resp.StatusCode == http.StatusPartialContent { + resp.StatusCode = http.StatusOK + resp.Status = "200 OK" + resp.Header.Del("Content-Range") + resp.Body = io.NopCloser(bytes.NewReader(payload)) + } } -} -// TestResumeMultipleFailuresWithinBudget_Succeeds verifies that multiple -// consecutive mid-stream failures are handled as long as the retry budget is -// sufficient, resulting in a fully correct payload. -func TestResumeMultipleFailuresWithinBudget_Succeeds(t *testing.T) { - // Arrange - url := "https://example.com/multi" - payload := bytes.Repeat([]byte{0x42}, 10_000) - ft := newFakeTransport() - ft.add(url, payload, &flakePlan{CutAfter: []int{500, 700, -1}}) - - // Act - client := newClient(ft, 5) - resp, err := client.Get(url) + client := &http.Client{Transport: New(ft)} + req, _ := http.NewRequest("GET", url, nil) + req.Header.Set("Range", "bytes=100-199") + resp, err := client.Do(req) if err != nil { t.Fatalf("GET: %v", err) } - t.Cleanup(func() { resp.Body.Close() }) + defer resp.Body.Close() - // Assert Content-Length for initial 200 - if resp.ContentLength != int64(len(payload)) { - t.Errorf("ContentLength = %d, want %d", resp.ContentLength, len(payload)) - } - - got, err := io.ReadAll(resp.Body) + data, err := io.ReadAll(resp.Body) if err != nil { t.Fatalf("read: %v", err) } + testutil.AssertDataEquals(t, data, payload) +} - // Assert - if !bytes.Equal(got, payload) { - t.Errorf("payload mismatch: got %d bytes, want %d", len(got), len(payload)) +// TestRedirectOnResume returns 3xx for resume request and expects a clear error. +func TestRedirectOnResume(t *testing.T) { + url := "https://example.com/redirect-on-resume" + payload := testutil.GenerateTestData(5000) + etag := `"strong"` + + ft := testutil.NewFakeTransport() + ft.Add(url, &testutil.FakeResource{ + Data: bytes.NewReader(payload), + Length: int64(len(payload)), + SupportsRange: true, + ETag: etag, + }) + ft.SetFailAfter(url, 2500) + + ft.ResponseHook = func(resp *http.Response) { + if resp.Request.Header.Get("Range") != "" { + resp.StatusCode = http.StatusFound + resp.Status = "302 Found" + resp.Header.Del("Content-Range") + resp.Body = io.NopCloser(bytes.NewReader(nil)) + } } -} -// TestExceedRetryBudget_Fails verifies that when the number of consecutive failures -// exceeds the configured retry budget, the read ultimately fails with an error. 
-func TestExceedRetryBudget_Fails(t *testing.T) { - // Arrange - url := "https://example.com/toosad" - payload := bytes.Repeat([]byte{0x99}, 4_096) - ft := newFakeTransport() - ft.add(url, payload, &flakePlan{CutAfter: []int{1000, 1000, 1000, 1000, 1000}}) - - // Act - client := newClient(ft, 2) + client := &http.Client{Transport: New(ft, WithMaxRetries(2))} resp, err := client.Get(url) if err != nil { t.Fatalf("GET: %v", err) } - t.Cleanup(func() { resp.Body.Close() }) + defer resp.Body.Close() - // Assert Content-Length for initial 200 - if resp.ContentLength != int64(len(payload)) { - t.Errorf("ContentLength = %d, want %d", resp.ContentLength, len(payload)) + _, err = io.ReadAll(resp.Body) + if err == nil || !strings.Contains(err.Error(), "redirect status") { + t.Fatalf("expected redirect error, got %v", err) } +} - _, rerr := io.ReadAll(resp.Body) +// TestWrongStartOnResume_IsRejected tests handling of unexpected range +// responses. +func TestWrongStartOnResume_IsRejected(t *testing.T) { + url := "https://example.com/wrong-start" + payload := testutil.GenerateTestData(5000) + + ft := testutil.NewFakeTransport() + ft.Add(url, &testutil.FakeResource{ + Data: bytes.NewReader(payload), + Length: int64(len(payload)), + SupportsRange: true, + ETag: `"test"`, + }) + + // Return wrong range on resume. + resumeAttempted := false + var muResume sync.Mutex + ft.ResponseHook = func(resp *http.Response) { + if resp.Request.Header.Get("Range") == "bytes=2500-" { + muResume.Lock() + resumeAttempted = true + muResume.Unlock() + // Return wrong start position. + resp.Header.Set("Content-Range", "bytes 3000-4999/5000") + resp.Body = io.NopCloser(testutil.NewFlakyReader( + bytes.NewReader(payload[3000:]), + int64(len(payload[3000:])), + 0, + )) + } + } + + // First fail after 2500 bytes. + ft.SetFailAfter(url, 2500) - // Assert - if rerr == nil { - t.Errorf("expected read error after exceeding retry budget, got nil") + client := &http.Client{ + Transport: New(ft, WithMaxRetries(3)), } -} -// TestWrongStartOnResume_IsRejected verifies that the client rejects a resume -// response whose Content-Range start differs from the requested start. -func TestWrongStartOnResume_IsRejected(t *testing.T) { - // Arrange - url := "https://example.com/wrongstart" - payload := bytes.Repeat([]byte("XYZ"), 3000) - ft := newFakeTransport() - ft.add(url, payload, &flakePlan{CutAfter: []int{1000, -1}, WrongStartOnResume: true}) - - // Act - client := newClient(ft, 2) resp, err := client.Get(url) if err != nil { t.Fatalf("GET: %v", err) } - t.Cleanup(func() { resp.Body.Close() }) + defer resp.Body.Close() - // Assert Content-Length for initial 200 - if resp.ContentLength != int64(len(payload)) { - t.Errorf("ContentLength = %d, want %d", resp.ContentLength, len(payload)) + _, err = io.ReadAll(resp.Body) + if err == nil { + t.Error("expected error due to wrong range start") } - _, rerr := io.ReadAll(resp.Body) - - // Assert - if rerr == nil { - t.Errorf("expected read error due to wrong Content-Range start, got nil") + muResume.Lock() + attempted := resumeAttempted + muResume.Unlock() + if !attempted { + t.Error("resume was not attempted") } } -// TestNon206OnResume_IsRejected verifies that a non-206 response to a resume -// request is rejected and ultimately causes the read to fail. +// TestNon206OnResume_IsRejected tests handling when server returns 200 +// instead of 206. 
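+// A 200 reply to a ranged resume means the If-Range validator no longer
+// matched; splicing that full body into the stream would corrupt it.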
func TestNon206OnResume_IsRejected(t *testing.T) { - // Arrange - url := "https://example.com/non206" - payload := bytes.Repeat([]byte{0xAA, 0xBB, 0xCC}, 2000) - ft := newFakeTransport() - ft.add(url, payload, &flakePlan{CutAfter: []int{1000, -1}, ForceNon206OnResume: true}) - - // Act - client := newClient(ft, 2) + url := "https://example.com/non-206" + payload := testutil.GenerateTestData(5000) + + ft := testutil.NewFakeTransport() + ft.Add(url, &testutil.FakeResource{ + Data: bytes.NewReader(payload), + Length: int64(len(payload)), + SupportsRange: true, + ETag: `"test"`, + }) + + // Return 200 on range request (simulating resource change). + ft.ResponseHook = func(resp *http.Response) { + if resp.Request.Header.Get("Range") == "bytes=2500-" { + resp.StatusCode = http.StatusOK + resp.Status = "200 OK" + resp.Header.Del("Content-Range") + resp.Body = io.NopCloser(testutil.NewFlakyReader( + bytes.NewReader(payload), + int64(len(payload)), + 0, + )) + } + } + + ft.SetFailAfter(url, 2500) + + client := &http.Client{ + Transport: New(ft, WithMaxRetries(3)), + } + resp, err := client.Get(url) if err != nil { t.Fatalf("GET: %v", err) } - t.Cleanup(func() { resp.Body.Close() }) + defer resp.Body.Close() - // Assert Content-Length for initial 200 - if resp.ContentLength != int64(len(payload)) { - t.Errorf("ContentLength = %d, want %d", resp.ContentLength, len(payload)) + _, err = io.ReadAll(resp.Body) + if err == nil || + err.Error() != "resumable: server returned 200 to a range request; resource may have changed" { + t.Errorf("expected specific error, got: %v", err) } +} + +// TestNoRangeSupport_PassesThrough_NoResume tests fallback when server +// doesn't support ranges. +func TestNoRangeSupport_PassesThrough_NoResume(t *testing.T) { + url := "https://example.com/no-range" + payload := testutil.GenerateTestData(5000) - _, rerr := io.ReadAll(resp.Body) + ft := testutil.NewFakeTransport() + ft.Add(url, &testutil.FakeResource{ + Data: bytes.NewReader(payload), + Length: int64(len(payload)), + SupportsRange: false, // No range support. + }) - // Assert - if rerr == nil { - t.Errorf("expected read error due to non-206 on resume, got nil") + // Simulate failure - should not be able to resume. + ft.SetFailAfter(url, 2500) + + client := &http.Client{ + Transport: New(ft, WithMaxRetries(3)), } -} -// TestNoRangeSupport_PassesThrough_NoResume verifies that when the server does -// not advertise range support, the wrapper does not attempt to resume and the -// mid-stream error bubbles up to the caller. 
-func TestNoRangeSupport_PassesThrough_NoResume(t *testing.T) { - // Arrange - url := "https://example.com/norange" - payload := bytes.Repeat([]byte("hello"), 500) - ft := newFakeTransport() - ft.add(url, payload, &flakePlan{NoRangeSupport: true, CutAfter: []int{200}}) - - // Act - client := newClient(ft, 3) resp, err := client.Get(url) if err != nil { - t.Fatalf("GET: %v", err) + t.Fatalf("GET failed: %v", err) } - t.Cleanup(func() { resp.Body.Close() }) + defer resp.Body.Close() - // Assert Content-Length for initial 200 - if resp.ContentLength != int64(len(payload)) { - t.Errorf("ContentLength = %d, want %d", resp.ContentLength, len(payload)) - } - - buf := make([]byte, 1<<20) - _, rerr := resp.Body.Read(buf) - if rerr == nil { - _, rerr = io.ReadAll(resp.Body) + got, err := io.ReadAll(resp.Body) + if err == nil { + t.Error("expected read error due to no range support and failure") } - // Assert - if rerr == nil { - t.Errorf("expected mid-stream error without resume support") + // Should only get partial data. + if len(got) >= len(payload) { + t.Errorf("got %d bytes, expected less than %d", len(got), len(payload)) } } -// TestIfRange_ETag_Matches_AllowsResume verifies that when the server requires -// If-Range and provides a strong ETag, the client sends the correct validator -// and the resume succeeds. +// TestIfRange_ETag_Matches_AllowsResume tests If-Range with ETag validation. func TestIfRange_ETag_Matches_AllowsResume(t *testing.T) { - // Arrange - url := "https://example.com/ifrange-etag-ok" - payload := bytes.Repeat([]byte("data-"), 1500) // 7,500 bytes - ft := newFakeTransport() - ft.add(url, payload, &flakePlan{CutAfter: []int{1200, -1}, RequireIfRange: true}) - - // Act - client := newClient(ft, 3) - resp, err := client.Get(url) - if err != nil { - t.Fatalf("GET: %v", err) + url := "https://example.com/if-range-etag" + payload := testutil.GenerateTestData(7500) + etag := `"strong-etag"` + + ft := testutil.NewFakeTransport() + ft.Add(url, &testutil.FakeResource{ + Data: bytes.NewReader(payload), + Length: int64(len(payload)), + SupportsRange: true, + ETag: etag, + }) + + // Simulate failure to trigger resume. + failCount := 0 + var muFail sync.Mutex + ft.ResponseHook = func(resp *http.Response) { + muFail.Lock() + fc := failCount + if resp.Request.Method == http.MethodGet && fc == 0 { + failCount = fc + 1 + muFail.Unlock() + // First request fails after 3000 bytes. + resp.Body = testutil.NewFlakyReader( + bytes.NewReader(payload), + int64(len(payload)), + 3000, + ) + return + } + muFail.Unlock() } - t.Cleanup(func() { resp.Body.Close() }) - // Assert Content-Length for initial 200 - if resp.ContentLength != int64(len(payload)) { - t.Errorf("ContentLength = %d, want %d", resp.ContentLength, len(payload)) + client := &http.Client{ + Transport: New(ft, WithMaxRetries(3)), + } + + resp, err := client.Get(url) + if err != nil { + t.Fatalf("GET failed: %v", err) } + defer resp.Body.Close() got, err := io.ReadAll(resp.Body) if err != nil { - t.Fatalf("read: %v", err) + t.Fatalf("read failed: %v", err) } - // Assert - if !bytes.Equal(got, payload) { - t.Errorf("payload mismatch: got %d bytes, want %d", len(got), len(payload)) + testutil.AssertDataEquals(t, got, payload) + + // Check If-Range header on resume request. 
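+ // The transport prefers a strong ETag over Last-Modified as the
+ // If-Range validator, so the resume request must echo it verbatim.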
+	headers := ft.GetRequestHeaders(url)
+	foundIfRange := false
+	for _, h := range headers {
+		if h.Get("Range") != "" {
+			if ifRange := h.Get("If-Range"); ifRange == etag {
+				foundIfRange = true
+				break
+			}
+		}
+	}
+
+	if !foundIfRange {
+		t.Error("expected If-Range header with ETag on resume")
 	}
 }
 
-// TestIfRange_ETag_ChangedOnResume_RejectsResume verifies that if the server
-// changes its ETag between the initial response and the resume request, the
-// client's If-Range will not match and the resume will be rejected.
+// TestIfRange_ETag_ChangedOnResume_RejectsResume tests ETag change detection.
 func TestIfRange_ETag_ChangedOnResume_RejectsResume(t *testing.T) {
-	// Arrange
-	url := "https://example.com/ifrange-etag-changed"
-	payload := bytes.Repeat([]byte("X"), 6000)
-	ft := newFakeTransport()
-	ft.add(url, payload, &flakePlan{CutAfter: []int{1000, -1}, RequireIfRange: true, ChangeETagOnResume: true})
-
-	// Act
-	client := newClient(ft, 2)
+	url := "https://example.com/etag-changed"
+	payload := testutil.GenerateTestData(5000)
+	originalETag := `"original"`
+	changedETag := `"changed"`
+
+	ft := testutil.NewFakeTransport()
+	ft.Add(url, &testutil.FakeResource{
+		Data:          bytes.NewReader(payload),
+		Length:        int64(len(payload)),
+		SupportsRange: true,
+		ETag:          originalETag,
+	})
+
+	// Change the ETag on the resume attempt.
+	ft.ResponseHook = func(resp *http.Response) {
+		if resp.Request.Header.Get("Range") != "" {
+			// Simulate a resource change.
+			resp.StatusCode = http.StatusOK
+			resp.Status = "200 OK"
+			resp.Header.Set("ETag", changedETag)
+			resp.Header.Del("Content-Range")
+			resp.Body = io.NopCloser(testutil.NewFlakyReader(
+				bytes.NewReader(payload),
+				int64(len(payload)),
+				0,
+			))
+		}
+	}
+
+	ft.SetFailAfter(url, 2500)
+
+	client := &http.Client{
+		Transport: New(ft, WithMaxRetries(3)),
+	}
+
 	resp, err := client.Get(url)
 	if err != nil {
 		t.Fatalf("GET: %v", err)
 	}
-	t.Cleanup(func() { resp.Body.Close() })
+	defer resp.Body.Close()
 
-	// Assert Content-Length for initial 200
-	if resp.ContentLength != int64(len(payload)) {
-		t.Errorf("ContentLength = %d, want %d", resp.ContentLength, len(payload))
+	_, err = io.ReadAll(resp.Body)
+	if err == nil ||
+		err.Error() != "resumable: server returned 200 to a range request; resource may have changed" {
+		t.Errorf("expected resource change error, got: %v", err)
 	}
+}
 
-	_, rerr := io.ReadAll(resp.Body)
+// TestIfRange_LastModified_Matches_AllowsResume tests If-Range with Last-Modified.
+func TestIfRange_LastModified_Matches_AllowsResume(t *testing.T) {
+	url := "https://example.com/if-range-lm"
+	payload := testutil.GenerateTestData(6000)
+	lastModified := "Wed, 21 Oct 2015 07:28:00 GMT"
+
+	ft := testutil.NewFakeTransport()
+	ft.Add(url, &testutil.FakeResource{
+		Data:          bytes.NewReader(payload),
+		Length:        int64(len(payload)),
+		SupportsRange: true,
+		LastModified:  lastModified,
+		// No ETag, so the client should fall back to Last-Modified.
+	})
+
+	// Simulate a mid-stream failure.
+	ft.SetFailAfter(url, 3000)
 
-	// Assert
-	if rerr == nil {
-		t.Errorf("expected read error due to If-Range (ETag) mismatch causing non-206 on resume")
+	client := &http.Client{
+		Transport: New(ft, WithMaxRetries(3)),
 	}
-}
 
-// TestIfRange_LastModified_Matches_AllowsResume verifies that when the server
-// omits ETag but provides Last-Modified, the client uses Last-Modified as the
-// If-Range validator and successfully resumes.
-func TestIfRange_LastModified_Matches_AllowsResume(t *testing.T) {
-	// Arrange
-	url := "https://example.com/ifrange-lm-ok"
-	payload := bytes.Repeat([]byte("LMOK"), 3000) // 12,000 bytes
-	ft := newFakeTransport()
-	ft.add(url, payload, &flakePlan{CutAfter: []int{1500, -1}, RequireIfRange: true, OmitETag: true, OmitLastModified: false})
-
-	// Act
-	client := newClient(ft, 3)
 	resp, err := client.Get(url)
 	if err != nil {
-		t.Fatalf("GET: %v", err)
-	}
-	t.Cleanup(func() { resp.Body.Close() })
-
-	// Assert Content-Length for initial 200
-	if resp.ContentLength != int64(len(payload)) {
-		t.Errorf("ContentLength = %d, want %d", resp.ContentLength, len(payload))
+		t.Fatalf("GET failed: %v", err)
 	}
+	defer resp.Body.Close()
 
 	got, err := io.ReadAll(resp.Body)
 	if err != nil {
-		t.Fatalf("read: %v", err)
+		t.Fatalf("read failed: %v", err)
 	}
 
-	// Assert
-	if !bytes.Equal(got, payload) {
-		t.Errorf("payload mismatch: got %d bytes, want %d", len(got), len(payload))
+	testutil.AssertDataEquals(t, got, payload)
+
+	// Check that If-Range uses Last-Modified.
+	headers := ft.GetRequestHeaders(url)
+	foundIfRange := false
+	for _, h := range headers {
+		if h.Get("Range") != "" {
+			if ifRange := h.Get("If-Range"); ifRange == lastModified {
+				foundIfRange = true
+				break
+			}
+		}
+	}
+
+	if !foundIfRange {
+		t.Error("expected If-Range header with Last-Modified on resume")
 	}
 }
 
-// TestIfRange_LastModified_ChangedOnResume_RejectsResume verifies that if the
-// server changes its Last-Modified timestamp between initial and resume, the
-// client's If-Range will not match and the resume will be rejected.
+// TestIfRange_LastModified_ChangedOnResume_RejectsResume tests Last-Modified
+// change detection.
 func TestIfRange_LastModified_ChangedOnResume_RejectsResume(t *testing.T) {
-	// Arrange
-	url := "https://example.com/ifrange-lm-changed"
-	payload := bytes.Repeat([]byte{0xAB, 0xCD}, 5000) // 10,000 bytes
-	ft := newFakeTransport()
-	ft.add(url, payload, &flakePlan{CutAfter: []int{800, -1}, RequireIfRange: true, OmitETag: true, OmitLastModified: false, ChangeLastModifiedOnResume: true})
-
-	// Act
-	client := newClient(ft, 2)
+	url := "https://example.com/lm-changed"
+	payload := testutil.GenerateTestData(5000)
+	originalLM := "Wed, 21 Oct 2015 07:28:00 GMT"
+	changedLM := "Thu, 22 Oct 2015 08:30:00 GMT"
+
+	ft := testutil.NewFakeTransport()
+	ft.Add(url, &testutil.FakeResource{
+		Data:          bytes.NewReader(payload),
+		Length:        int64(len(payload)),
+		SupportsRange: true,
+		LastModified:  originalLM,
+	})
+
+	// Change Last-Modified on the resume attempt.
+	ft.ResponseHook = func(resp *http.Response) {
+		if resp.Request.Header.Get("Range") != "" {
+			// Simulate a resource change.
+			resp.StatusCode = http.StatusOK
+			resp.Status = "200 OK"
+			resp.Header.Set("Last-Modified", changedLM)
+			resp.Header.Del("Content-Range")
+			resp.Body = io.NopCloser(testutil.NewFlakyReader(
+				bytes.NewReader(payload),
+				int64(len(payload)),
+				0,
+			))
+		}
+	}
+
+	ft.SetFailAfter(url, 2500)
+
+	client := &http.Client{
+		Transport: New(ft, WithMaxRetries(3)),
+	}
+
 	resp, err := client.Get(url)
 	if err != nil {
 		t.Fatalf("GET: %v", err)
 	}
-	t.Cleanup(func() { resp.Body.Close() })
+	defer resp.Body.Close()
 
-	// Assert Content-Length for initial 200
-	if resp.ContentLength != int64(len(payload)) {
-		t.Errorf("ContentLength = %d, want %d", resp.ContentLength, len(payload))
+	_, err = io.ReadAll(resp.Body)
+	if err == nil ||
+		err.Error() != "resumable: server returned 200 to a range request; resource may have changed" {
+		t.Errorf("expected resource change error, got: %v", err)
 	}
+}
+
+// TestIfRange_RequiredButUnavailable_MissingRejected tests behavior when no
+// validator is available.
+func TestIfRange_RequiredButUnavailable_MissingRejected(t *testing.T) {
+	url := "https://example.com/no-validator"
+	payload := testutil.GenerateTestData(5000)
 
-	_, rerr := io.ReadAll(resp.Body)
+	ft := testutil.NewFakeTransport()
+	ft.Add(url, &testutil.FakeResource{
+		Data:          bytes.NewReader(payload),
+		Length:        int64(len(payload)),
+		SupportsRange: true,
+		// No ETag or LastModified.
+	})
 
-	// Assert
-	if rerr == nil {
-		t.Errorf("expected read error due to If-Range (Last-Modified) mismatch causing non-206 on resume")
+	ft.SetFailAfter(url, 2500)
+
+	client := &http.Client{
+		Transport: New(ft, WithMaxRetries(3)),
 	}
-}
 
-// TestIfRange_RequiredButUnavailable_MissingRejected verifies that if the server
-// requires If-Range but provides no validators at all, the client cannot form
-// an If-Range and the resume will be rejected.
-func TestIfRange_RequiredButUnavailable_MissingRejected(t *testing.T) {
-	// Arrange
-	url := "https://example.com/ifrange-missing"
-	payload := bytes.Repeat([]byte("no-validator"), 1000)
-	ft := newFakeTransport()
-	ft.add(url, payload, &flakePlan{CutAfter: []int{300, -1}, RequireIfRange: true, OmitETag: true, OmitLastModified: true})
-
-	// Act
-	client := newClient(ft, 2)
 	resp, err := client.Get(url)
 	if err != nil {
 		t.Fatalf("GET: %v", err)
 	}
-	t.Cleanup(func() { resp.Body.Close() })
+	defer resp.Body.Close()
 
-	// Assert Content-Length for initial 200
-	if resp.ContentLength != int64(len(payload)) {
-		t.Errorf("ContentLength = %d, want %d", resp.ContentLength, len(payload))
+	_, err = io.ReadAll(resp.Body)
+	// Safer behavior: do not attempt a resume without a validator. Expect an
+	// error to be surfaced when the initial stream fails and cannot resume.
+	if err == nil {
+		t.Error("expected error due to missing resume validator")
 	}
+}
 
-	_, rerr := io.ReadAll(resp.Body)
+// TestIfRange_WeakETag_Present_UsesLastModified_AllowsResume tests that weak
+// ETags fall back to Last-Modified.
+func TestIfRange_WeakETag_Present_UsesLastModified_AllowsResume(t *testing.T) {
+	url := "https://example.com/weak-etag"
+	payload := testutil.GenerateTestData(10000)
+	lastModified := "Mon, 02 Jan 2006 15:04:05 GMT"
 
-	// Assert
-	if rerr == nil {
-		t.Errorf("expected read error because server required If-Range but provided no validators")
+	ft := testutil.NewFakeTransport()
+	ft.Add(url, &testutil.FakeResource{
+		Data:          bytes.NewReader(payload),
+		Length:        int64(len(payload)),
+		SupportsRange: true,
+		ETag:          `W/"weak-etag"`, // Weak ETag.
+		LastModified:  lastModified,
+	})
+
+	// Simulate a mid-stream failure.
+	ft.SetFailAfter(url, 5000)
+
+	client := &http.Client{
+		Transport: New(ft, WithMaxRetries(3)),
 	}
-}
 
-// TestIfRange_WeakETag_Present_UsesLastModified_AllowsResume verifies that when
-// the server advertises a WEAK ETag and also a Last-Modified timestamp, the
-// client will ignore the weak ETag for If-Range, use Last-Modified instead, and
-// the resume will succeed.
-func TestIfRange_WeakETag_Present_UsesLastModified_AllowsResume(t *testing.T) {
-	// Arrange: resource that advertises weak ETag + LM, requires If-Range, and cuts once.
-	url := "https://example.com/ifrange-weak-etag"
-	payload := bytes.Repeat([]byte("WEAK"), 2500)
-	ft := newFakeTransport()
-	ft.add(url, payload, &flakePlan{CutAfter: []int{600, -1}, RequireIfRange: true})
-	// Override strong ETag with a WEAK one.
-	ft.mu.Lock()
-	ft.etags[url] = `W/"weak-` + strings.ReplaceAll(url, "/", "_") + `"`
-	ft.mu.Unlock()
-
-	// Act: client should send If-Range with Last-Modified (not the weak ETag).
-	client := newClient(ft, 3)
 	resp, err := client.Get(url)
 	if err != nil {
-		t.Fatalf("GET: %v", err)
-	}
-	t.Cleanup(func() { resp.Body.Close() })
-
-	// Assert Content-Length for initial 200
-	if resp.ContentLength != int64(len(payload)) {
-		t.Errorf("ContentLength = %d, want %d", resp.ContentLength, len(payload))
+		t.Fatalf("GET failed: %v", err)
 	}
+	defer resp.Body.Close()
 
 	got, err := io.ReadAll(resp.Body)
 	if err != nil {
-		t.Fatalf("read: %v", err)
+		t.Fatalf("read failed: %v", err)
 	}
 
-	// Assert: full payload delivered successfully via resume.
-	if !bytes.Equal(got, payload) {
-		t.Errorf("payload mismatch: got %d bytes, want %d", len(got), len(payload))
+	testutil.AssertDataEquals(t, got, payload)
+
+	// Should use Last-Modified for If-Range, not the weak ETag: weak
+	// validators are not allowed in If-Range.
+	headers := ft.GetRequestHeaders(url)
+	sawResume := false
+	for _, h := range headers {
+		if h.Get("Range") != "" {
+			sawResume = true
+			ifRange := h.Get("If-Range")
+			if ifRange == `W/"weak-etag"` {
+				t.Error("should not use weak ETag for If-Range")
+			}
+			if ifRange != lastModified {
+				t.Errorf("expected If-Range with Last-Modified, got %q", ifRange)
+			}
+		}
 	}
+	if !sawResume {
+		t.Error("expected a resume (range) request")
+	}
 }
 
-// TestGzipContentEncoding_DisablesResume verifies that when the initial 200
-// response has Content-Encoding set (e.g., gzip), the transport declines to
-// wrap the body for resumption and thus a mid-stream failure bubbles up.
+// TestGzipContentEncoding_DisablesResume tests that Content-Encoding disables
+// resume: byte ranges address the encoded representation, so a compressed
+// stream cannot safely be restarted mid-stream.
 func TestGzipContentEncoding_DisablesResume(t *testing.T) {
-	// Arrange: range-capable server that serves gzip on initial response and then cuts.
-	url := "https://example.com/gzip-initial"
-	payload := bytes.Repeat([]byte("zip"), 4000)
-	ft := newFakeTransport()
-	ft.add(url, payload, &flakePlan{CutAfter: []int{500}, InitialContentEncoding: "gzip"})
-
-	// Act: client uses resumable transport, but it should refuse to wrap due to encoding.
-	client := newClient(ft, 3)
+	url := "https://example.com/gzip"
+	payload := testutil.GenerateTestData(12000)
+
+	ft := testutil.NewFakeTransport()
+	ft.Add(url, &testutil.FakeResource{
+		Data:          bytes.NewReader(payload),
+		Length:        int64(len(payload)),
+		SupportsRange: true,
+		Headers: http.Header{
+			"Content-Encoding": []string{"gzip"},
+		},
+	})
+
+	// Simulate a mid-stream failure.
+	ft.SetFailAfter(url, 6000)
+
+	client := &http.Client{
+		Transport: New(ft, WithMaxRetries(3)),
+	}
+
 	resp, err := client.Get(url)
 	if err != nil {
-		t.Fatalf("GET: %v", err)
+		t.Fatalf("GET failed: %v", err)
 	}
-	t.Cleanup(func() { resp.Body.Close() })
+	defer resp.Body.Close()
 
-	// Assert Content-Length for initial 200 (server advertises total)
-	if resp.ContentLength != int64(len(payload)) {
-		t.Errorf("ContentLength = %d, want %d", resp.ContentLength, len(payload))
+	got, err := io.ReadAll(resp.Body)
+	// Should fail because Content-Encoding prevents resume.
+	if err == nil {
+		t.Error("expected error due to Content-Encoding preventing resume")
 	}
-	_, rerr := io.ReadAll(resp.Body)
-
-	// Assert: we see an error because no resume was attempted under compression.
-	if rerr == nil {
-		t.Errorf("expected mid-stream error when initial response is compressed (no resume)")
+	// Should only have partial data.
+	if len(got) >= len(payload) {
+		t.Errorf("got %d bytes, expected less due to failure", len(got))
 	}
 }
 
-// TestResumeHeaders_ScrubbedAndIdentityEncoding verifies that on resume the client
-// sets Accept-Encoding to identity and scrubs conditional headers that could
-// conflict with If-Range semantics.
+// TestResumeHeaders_ScrubbedAndIdentityEncoding tests header handling on resume.
 func TestResumeHeaders_ScrubbedAndIdentityEncoding(t *testing.T) {
-	// Arrange: server supports ranges and will cut to force a resume.
-	url := "https://example.com/header-scrub"
-	payload := bytes.Repeat([]byte("H"), 4000)
-	ft := newFakeTransport()
-	ft.add(url, payload, &flakePlan{CutAfter: []int{600, -1}})
-
-	client := newClient(ft, 3)
-	// Build initial request with headers that should be scrubbed on resume.
+	url := "https://example.com/headers"
+	payload := testutil.GenerateTestData(5000)
+
+	ft := testutil.NewFakeTransport()
+	ft.Add(url, &testutil.FakeResource{
+		Data:          bytes.NewReader(payload),
+		Length:        int64(len(payload)),
+		SupportsRange: true,
+		ETag:          `"test"`,
+	})
+
+	// Check the headers on the resume request.
+	ft.RequestHook = func(req *http.Request) {
+		if req.Header.Get("Range") != "" {
+			// Accept-Encoding must be identity so byte offsets stay meaningful.
+			if ae := req.Header.Get("Accept-Encoding"); ae != "identity" {
+				t.Errorf("expected Accept-Encoding: identity, got: %q", ae)
+			}
+			// Conditional headers must be removed; they could conflict with
+			// If-Range semantics.
+			if req.Header.Get("If-Modified-Since") != "" {
+				t.Error("If-Modified-Since should be removed on resume")
+			}
+			if req.Header.Get("If-None-Match") != "" {
+				t.Error("If-None-Match should be removed on resume")
+			}
+		}
+	}
+
+	ft.SetFailAfter(url, 2500)
+
+	client := &http.Client{
+		Transport: New(ft, WithMaxRetries(3)),
+	}
+
+	// Create a request carrying headers that should be scrubbed on resume.
 	req, err := http.NewRequest("GET", url, nil)
 	if err != nil {
-		t.Fatalf("new request: %v", err)
+		t.Fatalf("create request: %v", err)
 	}
-	req.Header.Set("Accept-Encoding", "gzip") // will be overridden to identity on resume
-	req.Header.Set("If-None-Match", "\"foo\"")
-	req.Header.Set("If-Modified-Since", time.Unix(1_600_000_000, 0).UTC().Format(http.TimeFormat))
-	req.Header.Set("If-Match", "\"bar\"")
-	req.Header.Set("If-Unmodified-Since", time.Unix(1_600_000_100, 0).UTC().Format(http.TimeFormat))
+	req.Header.Set("Accept-Encoding", "gzip, deflate")
+	req.Header.Set("If-Modified-Since", "Wed, 21 Oct 2015 07:28:00 GMT")
+	req.Header.Set("If-None-Match", `"other"`)
 
-	// Act: perform request and read to completion (triggering a resume once).
 	resp, err := client.Do(req)
 	if err != nil {
-		t.Fatalf("Do: %v", err)
+		t.Fatalf("GET: %v", err)
 	}
-	t.Cleanup(func() { resp.Body.Close() })
+	defer resp.Body.Close()
 
-	// Assert Content-Length for initial 200
-	if resp.ContentLength != int64(len(payload)) {
-		t.Errorf("ContentLength = %d, want %d", resp.ContentLength, len(payload))
+	got, err := io.ReadAll(resp.Body)
+	if err != nil {
+		t.Fatalf("read: %v", err)
 	}
-	_, _ = io.ReadAll(resp.Body)
+	testutil.AssertDataEquals(t, got, payload)
+}
 
-	// Assert: fetch recorded headers for each segment.
-	hs := ft.segmentHeaders(url)
-	if len(hs) < 2 {
-		t.Fatalf("expected at least 2 segments (initial + resume), got %d", len(hs))
+
+// TestRangeRequest_Initial tests resume with an initial Range request.
+func TestRangeRequest_Initial(t *testing.T) {
+	url := "https://example.com/range-initial"
+	payload := testutil.GenerateTestData(10240)
+
+	ft := testutil.NewFakeTransport()
+	ft.Add(url, &testutil.FakeResource{
+		Data:          bytes.NewReader(payload),
+		Length:        int64(len(payload)),
+		SupportsRange: true,
+		ETag:          `"range-test"`,
+	})
+
+	// Simulate a failure on the range request.
+	failCount := 0
+	var muRange sync.Mutex
+	ft.ResponseHook = func(resp *http.Response) {
+		muRange.Lock()
+		fc := failCount
+		if resp.Request.Header.Get("Range") == "bytes=1024-5119" && fc == 0 {
+			failCount = fc + 1
+			muRange.Unlock()
+			// Fail after 2000 bytes of the range.
+			rangeData := payload[1024:5120]
+			resp.Body = testutil.NewFlakyReader(
+				bytes.NewReader(rangeData),
+				int64(len(rangeData)),
+				2000,
+			)
+			return
+		}
+		muRange.Unlock()
 	}
-	initH, resumeH := hs[0], hs[1]
 
-	// Initial request kept our original Accept-Encoding; resume must be identity.
-	if got := strings.ToLower(resumeH.Get("Accept-Encoding")); got != "identity" {
-		t.Errorf("resume Accept-Encoding = %q, want %q", got, "identity")
+	// Create a request with an initial Range header.
+	req, err := http.NewRequest("GET", url, nil)
+	if err != nil {
+		t.Fatalf("create request: %v", err)
 	}
+	req.Header.Set("Range", "bytes=1024-5119")
 
-	// Conditional headers must be scrubbed on resume.
-	condKeys := []string{"If-None-Match", "If-Modified-Since", "If-Match", "If-Unmodified-Since"}
-	for _, k := range condKeys {
-		if v := resumeH.Get(k); v != "" {
-			t.Errorf("resume header %s = %q, want empty", k, v)
-		}
+	client := &http.Client{
+		Transport: New(ft, WithMaxRetries(3)),
 	}
 
-	// Sanity: they were present on the initial request to prove scrubbing happened.
-	if initH.Get("If-None-Match") == "" || initH.Get("If-Modified-Since") == "" || initH.Get("If-Match") == "" || initH.Get("If-Unmodified-Since") == "" {
-		t.Errorf("expected conditional headers on initial request for comparison")
+	resp, err := client.Do(req)
+	if err != nil {
+		t.Fatalf("GET failed: %v", err)
 	}
+	defer resp.Body.Close()
 
-	// Range and If-Range should be present on resume.
-	if r := resumeH.Get("Range"); r == "" || !strings.HasPrefix(strings.ToLower(r), "bytes=") {
-		t.Errorf("resume Range missing/invalid: %q", r)
+	got, err := io.ReadAll(resp.Body)
+	if err != nil {
+		t.Fatalf("read failed: %v", err)
 	}
-	if ir := resumeH.Get("If-Range"); ir == "" {
-		t.Errorf("resume If-Range missing")
+
+	want := payload[1024:5120]
+	testutil.AssertDataEquals(t, got, want)
+
+	// Check that the resume happened with an adjusted range.
+	headers := ft.GetRequestHeaders(url)
+	foundResume := false
+	for _, h := range headers {
+		rangeHeader := h.Get("Range")
+		if rangeHeader != "" && rangeHeader != "bytes=1024-5119" {
+			foundResume = true
+			t.Logf("Resume range: %s", rangeHeader)
+		}
 	}
-}
 
-// ─────────────────────────────── Initial-Range tests ───────────────────────────────
+	if !foundResume {
+		t.Error("expected resume with adjusted range")
+	}
 }
 
-// TestRangeInitial_ZeroToN_NoCuts_Succeeds verifies that when the *initial* request
-// specifies a Range from 0..N, the transport delivers exactly that slice without
-// any failures or resumes.
+// Additional range request tests for comprehensive coverage
 func TestRangeInitial_ZeroToN_NoCuts_Succeeds(t *testing.T) {
-	// Arrange
 	url := "https://example.com/range-0-n"
-	payload := bytes.Repeat([]byte("0123456789"), 1024) // 10,240 bytes
-	N := int64(2047)
-	ft := newFakeTransport()
-	ft.add(url, payload, &flakePlan{CutAfter: []int{-1}})
+	payload := testutil.GenerateTestData(5000)
 
-	// Act: initial request is a Range request
-	client := newClient(ft, 3)
-	req, _ := http.NewRequest("GET", url, nil)
-	req.Header.Set("Range", "bytes=0-"+strconv.FormatInt(N, 10))
-	resp, err := client.Do(req)
+	ft := testutil.NewFakeTransport()
+	ft.Add(url, &testutil.FakeResource{
+		Data:          bytes.NewReader(payload),
+		Length:        int64(len(payload)),
+		SupportsRange: true,
+	})
+
+	req, err := http.NewRequest("GET", url, nil)
 	if err != nil {
-		t.Fatalf("Do: %v", err)
+		t.Fatal(err)
 	}
-	t.Cleanup(func() { resp.Body.Close() })
+	req.Header.Set("Range", "bytes=0-2499")
 
-	// Assert Content-Length for 206 (0..N inclusive)
-	if resp.ContentLength != (N + 1) {
-		t.Errorf("ContentLength = %d, want %d", resp.ContentLength, N+1)
+	client := &http.Client{Transport: New(ft)}
+
+	resp, err := client.Do(req)
+	if err != nil {
+		t.Fatal(err)
 	}
+	defer resp.Body.Close()
 
 	got, err := io.ReadAll(resp.Body)
-	if err != nil && err != io.EOF {
-		t.Fatalf("read: %v", err)
+	if err != nil {
+		t.Fatal(err)
 	}
 
-	// Assert
-	want := payload[0 : N+1]
-	if !bytes.Equal(got, want) {
-		t.Errorf("payload mismatch: got %d bytes, want %d", len(got), len(want))
-	}
+	want := payload[0:2500]
+	testutil.AssertDataEquals(t, got, want)
 }
 
-// TestRangeInitial_MidSpan_NoCuts_Succeeds verifies a Range N..M (mid-file)
-// succeeds without any resumes and matches the exact slice.
 func TestRangeInitial_MidSpan_NoCuts_Succeeds(t *testing.T) {
-	// Arrange
-	url := "https://example.com/range-n-m"
-	payload := bytes.Repeat([]byte("ABCDEFGH"), 2048) // 16,384 bytes
-	N := int64(500)
-	M := int64(3499)
-	ft := newFakeTransport()
-	ft.add(url, payload, &flakePlan{CutAfter: []int{-1}})
-
-	// Act
-	client := newClient(ft, 3)
-	req, _ := http.NewRequest("GET", url, nil)
-	req.Header.Set("Range", "bytes="+strconv.FormatInt(N, 10)+"-"+strconv.FormatInt(M, 10))
-	resp, err := client.Do(req)
+	url := "https://example.com/range-mid"
+	payload := testutil.GenerateTestData(5000)
+
+	ft := testutil.NewFakeTransport()
+	ft.Add(url, &testutil.FakeResource{
+		Data:          bytes.NewReader(payload),
+		Length:        int64(len(payload)),
+		SupportsRange: true,
+	})
+
+	req, err := http.NewRequest("GET", url, nil)
 	if err != nil {
-		t.Fatalf("Do: %v", err)
+		t.Fatal(err)
 	}
-	t.Cleanup(func() { resp.Body.Close() })
+	req.Header.Set("Range", "bytes=1000-1999")
+
+	client := &http.Client{Transport: New(ft)}
 
-	// Assert Content-Length for 206 (N..M inclusive)
-	wantCL := (M - N + 1)
-	if resp.ContentLength != wantCL {
-		t.Errorf("ContentLength = %d, want %d", resp.ContentLength, wantCL)
+	resp, err := client.Do(req)
+	if err != nil {
+		t.Fatal(err)
 	}
+	defer resp.Body.Close()
 
 	got, err := io.ReadAll(resp.Body)
-	if err != nil && err != io.EOF {
-		t.Fatalf("read: %v", err)
+	if err != nil {
+		t.Fatal(err)
 	}
 
-	// Assert
-	want := payload[N : M+1]
-	if !bytes.Equal(got, want) {
-		t.Errorf("payload mismatch: got %d bytes, want %d", len(got), len(want))
-	}
+	want := payload[1000:2000]
+	testutil.AssertDataEquals(t, got, want)
 }
 
-// TestRangeInitial_FromNToEnd_NoCuts_Succeeds verifies a Range N..end request
-// ("bytes=N-") succeeds and returns the tail of the object.
 func TestRangeInitial_FromNToEnd_NoCuts_Succeeds(t *testing.T) {
-	// Arrange
-	url := "https://example.com/range-n-end"
-	payload := bytes.Repeat([]byte("xyz"), 5000) // 15,000 bytes
-	N := int64(2500)
-	ft := newFakeTransport()
-	ft.add(url, payload, &flakePlan{CutAfter: []int{-1}})
-
-	// Act
-	client := newClient(ft, 3)
-	req, _ := http.NewRequest("GET", url, nil)
-	req.Header.Set("Range", "bytes="+strconv.FormatInt(N, 10)+"-")
-	resp, err := client.Do(req)
+	url := "https://example.com/range-to-end"
+	payload := testutil.GenerateTestData(5000)
+
+	ft := testutil.NewFakeTransport()
+	ft.Add(url, &testutil.FakeResource{
+		Data:          bytes.NewReader(payload),
+		Length:        int64(len(payload)),
+		SupportsRange: true,
+	})
+
+	req, err := http.NewRequest("GET", url, nil)
 	if err != nil {
-		t.Fatalf("Do: %v", err)
+		t.Fatal(err)
 	}
-	t.Cleanup(func() { resp.Body.Close() })
+	req.Header.Set("Range", "bytes=3000-")
 
-	// Assert Content-Length for 206 (N..end inclusive)
-	wantCL := int64(len(payload)) - N
-	if resp.ContentLength != wantCL {
-		t.Errorf("ContentLength = %d, want %d", resp.ContentLength, wantCL)
+	client := &http.Client{Transport: New(ft)}
+
+	resp, err := client.Do(req)
+	if err != nil {
+		t.Fatal(err)
 	}
+	defer resp.Body.Close()
 
 	got, err := io.ReadAll(resp.Body)
-	if err != nil && err != io.EOF {
-		t.Fatalf("read: %v", err)
+	if err != nil {
+		t.Fatal(err)
 	}
 
-	// Assert
-	want := payload[N:]
-	if !bytes.Equal(got, want) {
-		t.Errorf("payload mismatch: got %d bytes, want %d", len(got), len(want))
-	}
+	want := payload[3000:]
+	testutil.AssertDataEquals(t, got, want)
 }
 
-// TestRangeInitial_ZeroToN_WithCut_Resumes verifies that a Range 0..N with a
-// mid-stream cut resumes correctly and still yields the exact slice.
 func TestRangeInitial_ZeroToN_WithCut_Resumes(t *testing.T) {
-	// Arrange
 	url := "https://example.com/range-0-n-cut"
-	payload := bytes.Repeat([]byte("Q"), 9000)
-	N := int64(4095)
-	ft := newFakeTransport()
-	ft.add(url, payload, &flakePlan{CutAfter: []int{512, -1}}) // cut during initial segment
+	payload := testutil.GenerateTestData(5000)
+
+	ft := testutil.NewFakeTransport()
+	ft.Add(url, &testutil.FakeResource{
+		Data:          bytes.NewReader(payload),
+		Length:        int64(len(payload)),
+		SupportsRange: true,
+		ETag:          `"test"`,
+	})
+
+	// Fail the range request partway through.
+	failCount := 0
+	var muCut sync.Mutex
+	ft.ResponseHook = func(resp *http.Response) {
+		muCut.Lock()
+		defer muCut.Unlock()
+		if resp.Request.Header.Get("Range") == "bytes=0-2499" && failCount == 0 {
+			failCount++
+			resp.Body = testutil.NewFlakyReader(
+				bytes.NewReader(payload[0:2500]),
+				int64(len(payload[0:2500])),
+				1000,
+			)
+		}
+	}
 
-	// Act
-	client := newClient(ft, 4)
-	req, _ := http.NewRequest("GET", url, nil)
-	req.Header.Set("Range", "bytes=0-"+strconv.FormatInt(N, 10))
-	resp, err := client.Do(req)
+	req, err := http.NewRequest("GET", url, nil)
 	if err != nil {
-		t.Fatalf("Do: %v", err)
+		t.Fatal(err)
 	}
-	t.Cleanup(func() { resp.Body.Close() })
+	req.Header.Set("Range", "bytes=0-2499")
 
-	// Assert Content-Length for 206 (0..N inclusive)
-	if resp.ContentLength != (N + 1) {
-		t.Errorf("ContentLength = %d, want %d", resp.ContentLength, N+1)
+	client := &http.Client{Transport: New(ft, WithMaxRetries(3))}
+
+	resp, err := client.Do(req)
+	if err != nil {
+		t.Fatal(err)
 	}
+	defer resp.Body.Close()
 
 	got, err := io.ReadAll(resp.Body)
-	if err != nil && err != io.EOF {
-		t.Fatalf("read: %v", err)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	want := payload[0:2500]
+	testutil.AssertDataEquals(t, got, want)
+
+	// Verify the resume happened.
+	headers := ft.GetRequestHeaders(url)
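+	// The cut delivered 1000 bytes of the requested bytes=0-2499, so the
+	// resume is expected to re-request the remainder as bytes=1000-2499.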
+	foundResume := false
+	for _, h := range headers {
+		rangeHeader := h.Get("Range")
+		if rangeHeader != "" && rangeHeader != "bytes=0-2499" {
+			foundResume = true
+			if rangeHeader != "bytes=1000-2499" {
+				t.Errorf("expected resume at bytes=1000-2499, got: %s", rangeHeader)
+			}
+		}
 	}
 
-	// Assert
-	want := payload[:N+1]
-	if !bytes.Equal(got, want) {
-		t.Errorf("payload mismatch: got %d bytes, want %d", len(got), len(want))
+	if !foundResume {
+		t.Error("expected resume")
 	}
 }
 
-// TestRangeInitial_MidSpan_WithMultipleCuts_Resumes verifies a Range N..M with
-// multiple failures is properly reassembled within the retry budget.
 func TestRangeInitial_MidSpan_WithMultipleCuts_Resumes(t *testing.T) {
-	// Arrange
-	url := "https://example.com/range-n-m-cuts"
-	payload := bytes.Repeat([]byte{0xAA, 0xBB, 0xCC, 0xDD}, 5000) // 20,000 bytes
-	N := int64(1000)
-	M := int64(9999)
-	ft := newFakeTransport()
-	ft.add(url, payload, &flakePlan{CutAfter: []int{300, 400, -1}}) // two cuts, then ok
-
-	// Act
-	client := newClient(ft, 5)
-	req, _ := http.NewRequest("GET", url, nil)
-	req.Header.Set("Range", "bytes="+strconv.FormatInt(N, 10)+"-"+strconv.FormatInt(M, 10))
-	resp, err := client.Do(req)
+	url := "https://example.com/range-mid-cuts"
+	payload := testutil.GenerateTestData(10000)
+
+	ft := testutil.NewFakeTransport()
+	ft.Add(url, &testutil.FakeResource{
+		Data:          bytes.NewReader(payload),
+		Length:        int64(len(payload)),
+		SupportsRange: true,
+		ETag:          `"test"`,
+	})
+
+	// Multiple failures on the range request
+	failCount := 0
+	var muCut sync.Mutex
+	ft.ResponseHook = func(resp *http.Response) {
+		rangeHeader := resp.Request.Header.Get("Range")
+		muCut.Lock()
+		fc := failCount
+		if rangeHeader == "bytes=2000-5999" && fc == 0 {
+			failCount = fc + 1
+			muCut.Unlock()
+			resp.Body = testutil.NewFlakyReader(
+				bytes.NewReader(payload[2000:6000]),
+				int64(len(payload[2000:6000])),
+				1000,
+			)
+			return
+		} else if rangeHeader == "bytes=3000-5999" && fc == 1 {
+			failCount = fc + 1
+			muCut.Unlock()
+			resp.Body = testutil.NewFlakyReader(
+				bytes.NewReader(payload[3000:6000]),
+				int64(len(payload[3000:6000])),
+				1500,
+			)
+			return
+		}
+		muCut.Unlock()
+	}
+
+	req, err := http.NewRequest("GET", url, nil)
 	if err != nil {
-		t.Fatalf("Do: %v", err)
+		t.Fatal(err)
 	}
-	t.Cleanup(func() { resp.Body.Close() })
+	req.Header.Set("Range", "bytes=2000-5999")
+
+	client := &http.Client{Transport: New(ft, WithMaxRetries(5))}
 
-	// Assert Content-Length for 206 (N..M inclusive)
-	wantCL := (M - N + 1)
-	if resp.ContentLength != wantCL {
-		t.Errorf("ContentLength = %d, want %d", resp.ContentLength, wantCL)
+	resp, err := client.Do(req)
+	if err != nil {
+		t.Fatal(err)
 	}
+	defer resp.Body.Close()
 
 	got, err := io.ReadAll(resp.Body)
-	if err != nil && err != io.EOF {
-		t.Fatalf("read: %v", err)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	want := payload[2000:6000]
+	testutil.AssertDataEquals(t, got, want)
+
+	// Check that multiple resumes happened.
+	reqs := ft.GetRequests()
+	var rangeCount int
+	for _, r := range reqs {
+		if r.Header.Get("Range") != "" {
+			rangeCount++
+		}
 	}
 
-	// Assert
-	want := payload[N : M+1]
-	if !bytes.Equal(got, want) {
-		t.Errorf("payload mismatch: got %d bytes, want %d", len(got), len(want))
+	if rangeCount < 3 {
+		t.Errorf("expected at least 3 range requests, got %d", rangeCount)
 	}
 }
 
-// TestRangeInitial_FromNToEnd_WithCut_Resumes verifies a Range N..end request
-// resumes correctly after a mid-stream failure and returns the tail of the object.
 func TestRangeInitial_FromNToEnd_WithCut_Resumes(t *testing.T) {
-	// Arrange
-	url := "https://example.com/range-n-end-cut"
-	payload := bytes.Repeat([]byte("tail"), 6000) // 24,000 bytes
-	N := int64(7777)
-	ft := newFakeTransport()
-	ft.add(url, payload, &flakePlan{CutAfter: []int{1024, -1}})
-
-	// Act
-	client := newClient(ft, 3)
-	req, _ := http.NewRequest("GET", url, nil)
-	req.Header.Set("Range", "bytes="+strconv.FormatInt(N, 10)+"-")
-	resp, err := client.Do(req)
+	url := "https://example.com/range-to-end-cut"
+	payload := testutil.GenerateTestData(10000)
+
+	ft := testutil.NewFakeTransport()
+	ft.Add(url, &testutil.FakeResource{
+		Data:          bytes.NewReader(payload),
+		Length:        int64(len(payload)),
+		SupportsRange: true,
+		ETag:          `"test"`,
+	})
+
+	// Fail the open-ended range request
+	failCount := 0
+	var muOpen sync.Mutex
+	ft.ResponseHook = func(resp *http.Response) {
+		muOpen.Lock()
+		fc := failCount
+		if resp.Request.Header.Get("Range") == "bytes=7000-" && fc == 0 {
+			failCount = fc + 1
+			muOpen.Unlock()
+			resp.Body = testutil.NewFlakyReader(
+				bytes.NewReader(payload[7000:]),
+				int64(len(payload[7000:])),
+				1500,
+			)
+			return
+		}
+		muOpen.Unlock()
+	}
+
+	req, err := http.NewRequest("GET", url, nil)
 	if err != nil {
-		t.Fatalf("Do: %v", err)
+		t.Fatal(err)
 	}
-	t.Cleanup(func() { resp.Body.Close() })
+	req.Header.Set("Range", "bytes=7000-")
+
+	client := &http.Client{Transport: New(ft, WithMaxRetries(3))}
 
-	// Assert Content-Length for 206 (N..end inclusive)
-	wantCL := int64(len(payload)) - N
-	if resp.ContentLength != wantCL {
-		t.Errorf("ContentLength = %d, want %d", resp.ContentLength, wantCL)
+	resp, err := client.Do(req)
+	if err != nil {
+		t.Fatal(err)
 	}
+	defer resp.Body.Close()
 
 	got, err := io.ReadAll(resp.Body)
-	if err != nil && err != io.EOF {
-		t.Fatalf("read: %v", err)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	want := payload[7000:]
+	testutil.AssertDataEquals(t, got, want)
+
+	// Verify resume happened
+	headers := ft.GetRequestHeaders(url)
+	foundResume := false
+	for _, h := range headers {
+		rangeHeader := h.Get("Range")
+		if rangeHeader != "" && rangeHeader != "bytes=7000-" {
+			foundResume = true
+			// Accept either open-ended or closed range
+			if rangeHeader != "bytes=8500-" && rangeHeader != "bytes=8500-9999" {
+				t.Errorf("expected resume at bytes=8500- or bytes=8500-9999, got: %s", rangeHeader)
+			}
+		}
 	}
 
-	// Assert
-	want := payload[N:]
-	if !bytes.Equal(got, want) {
-		t.Errorf("payload mismatch: got %d bytes, want %d", len(got), len(want))
+	if !foundResume {
+		t.Error("expected resume")
 	}
 }
 
-// TestRangeInitial_ResumeHeaderStart_Correct asserts that the resume request's
-// Range header starts exactly at initialStart + bytesDelivered.
 func TestRangeInitial_ResumeHeaderStart_Correct(t *testing.T) {
-	// Arrange: Range 0..2047 with a cut after 512 bytes on initial segment.
 	url := "https://example.com/range-header-check"
-	payload := bytes.Repeat([]byte("H"), 4096)
-	N := int64(2047)
-	cut := 512
-	ft := newFakeTransport()
-	ft.add(url, payload, &flakePlan{CutAfter: []int{cut, -1}})
-
-	client := newClient(ft, 3)
-	req, _ := http.NewRequest("GET", url, nil)
-	req.Header.Set("Range", "bytes=0-"+strconv.FormatInt(N, 10))
-
-	// Act: perform the request and read to completion (forcing one resume).
-	resp, err := client.Do(req)
-	if err != nil {
-		t.Fatalf("Do: %v", err)
+	payload := testutil.GenerateTestData(5000)
+
+	ft := testutil.NewFakeTransport()
+	ft.Add(url, &testutil.FakeResource{
+		Data:          bytes.NewReader(payload),
+		Length:        int64(len(payload)),
+		SupportsRange: true,
+		ETag:          `"test"`,
+	})
+
+	// Fail after 234 bytes of the range, i.e. at absolute offset 1234.
+	failCount := 0
+	var muHdr sync.Mutex
+	ft.ResponseHook = func(resp *http.Response) {
+		muHdr.Lock()
+		fc := failCount
+		if resp.Request.Header.Get("Range") == "bytes=1000-2999" && fc == 0 {
+			failCount = fc + 1
+			muHdr.Unlock()
+			rangeData := payload[1000:3000]
+			resp.Body = testutil.NewFlakyReader(
+				bytes.NewReader(rangeData),
+				int64(len(rangeData)),
+				234,
+			) // 1000 + 234 = 1234 bytes of the object delivered in total.
+			return
+		}
+		muHdr.Unlock()
 	}
-	t.Cleanup(func() { resp.Body.Close() })
 
-	// Assert Content-Length for 206 (0..N inclusive)
-	if resp.ContentLength != (N + 1) {
-		t.Errorf("ContentLength = %d, want %d", resp.ContentLength, N+1)
+	req, err := http.NewRequest("GET", url, nil)
+	if err != nil {
+		t.Fatal(err)
 	}
+	req.Header.Set("Range", "bytes=1000-2999")
 
-	_, _ = io.ReadAll(resp.Body)
+	client := &http.Client{Transport: New(ft, WithMaxRetries(3))}
 
-	// Assert: check second segment's Range header.
-	hs := ft.segmentHeaders(url)
-	if len(hs) < 2 {
-		t.Fatalf("expected at least 2 segments (initial + resume), got %d", len(hs))
-	}
-	resumeRange := hs[1].Get("Range")
-	want := "bytes=" + strconv.FormatInt(int64(cut), 10) + "-" + strconv.FormatInt(N, 10)
-	if resumeRange != want {
-		t.Errorf("resume Range header = %q, want %q", resumeRange, want)
+	resp, err := client.Do(req)
+	if err != nil {
+		t.Fatal(err)
 	}
-}
+	defer resp.Body.Close()
 
-// ─────────────────────────────── Parser tests ───────────────────────────────
-
-// TestParseSingleRange exercises valid and invalid single-range specs.
-func TestParseSingleRange(t *testing.T) {
-	cases := []struct {
-		in         string
-		start, end int64
-		ok         bool
-	}{
-		{"", 0, -1, false},
-		{"bytes=0-99", 0, 99, true},
-		{"bytes=0-", 0, -1, true},
-		{"bytes=5-5", 5, 5, true},
-		{"BYTES=7-9", 7, 9, true},
-		{"bytes=10-5", 0, -1, false}, // end before start
-		{"bytes=-100", 0, -1, false}, // suffix not supported
-		{"items=0-10", 0, -1, false},
-		{"bytes=0-1,3-5", 0, -1, false}, // multi-range unsupported
-	}
-	for _, tc := range cases {
-		start, end, ok := parseSingleRange(tc.in)
-		if start != tc.start || end != tc.end || ok != tc.ok {
-			t.Errorf("parseSingleRange(%q) = (%d,%d,%v), want (%d,%d,%v)", tc.in, start, end, ok, tc.start, tc.end, tc.ok)
-		}
+	got, err := io.ReadAll(resp.Body)
+	if err != nil {
+		t.Fatal(err)
 	}
-}
 
-// TestParseContentRange exercises valid and invalid Content-Range headers.
-func TestParseContentRange(t *testing.T) {
-	cases := []struct {
-		in         string
-		start, end int64
-		total      int64
-		ok         bool
-	}{
-		{"", 0, -1, -1, false},
-		{"bytes 0-99/200", 0, 99, 200, true},
-		{"BYTES 1-1/2", 1, 1, 2, true},
-		{"bytes 0-0/*", 0, 0, -1, true},
-		{"items 0-1/2", 0, -1, -1, false},
-		{"bytes 0-99/abc", 0, -1, -1, false},
-		{"bytes 5-4/10", 5, 4, 10, true}, // parser accepts; semantic check happens elsewhere
-	}
-	for _, tc := range cases {
-		start, end, total, ok := parseContentRange(tc.in)
-		if start != tc.start || end != tc.end || total != tc.total || ok != tc.ok {
-			t.Errorf("parseContentRange(%q) = (%d,%d,%d,%v), want (%d,%d,%d,%v)", tc.in, start, end, total, ok, tc.start, tc.end, tc.total, tc.ok)
+	want := payload[1000:3000]
+	testutil.AssertDataEquals(t, got, want)
+
+	// Check that the resume request has the correct start position.
+	headers := ft.GetRequestHeaders(url)
+	for _, h := range headers {
+		rangeHeader := h.Get("Range")
+		if rangeHeader != "" && rangeHeader != "bytes=1000-2999" {
+			if rangeHeader != "bytes=1234-2999" {
+				t.Errorf("expected resume at bytes=1234-2999, got: %s", rangeHeader)
+			}
+			break
 		}
 	}
 }
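+
+// Note on the resume offsets asserted above: an initial request for
+// "bytes=S-E" that dies after D delivered bytes resumes at S+D
+// (e.g. 1000+234 = 1234); for an open-ended "bytes=S-" the resume may be
+// "bytes=(S+D)-" or the equivalent bounded form.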