import argparse
import asyncio
import os
import sys
import time

import nats

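# uvloop is optional; if it is installed, use it as a faster drop-in event loop.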
try:
    import uvloop

    asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
except ImportError:
    pass

DEFAULT_NUM_FETCHES = 10
DEFAULT_TIMEOUT = 30
DEFAULT_BUCKET = "rethink"
DEFAULT_OBJECT = "RethinkConnectivity.mp4"


class ProgressFileWrapper:
    """
    A file wrapper that shows download progress as data is written.
    """

    def __init__(self, file_obj, total_size: int, object_name: str):
        self.file = file_obj
        self.total_size = total_size
        self.object_name = object_name
        self.bytes_written = 0
        self.last_progress = -1
        self.start_time = time.time()

    def write(self, data):
        """Write data to file and update progress."""
        result = self.file.write(data)
        self.bytes_written += len(data)
        self._update_progress()
        return result

    def _update_progress(self):
        """Update progress display."""
        if self.total_size <= 0:
            return

        progress = int((self.bytes_written / self.total_size) * 100)

        # Only update every 5% to avoid too much output
        if progress >= self.last_progress + 5:
            elapsed = time.time() - self.start_time
            if elapsed > 0:
                speed_mbps = (self.bytes_written / (1024 * 1024)) / elapsed
                mb_written = self.bytes_written / (1024 * 1024)
                mb_total = self.total_size / (1024 * 1024)

                # Overwrite the current line with the updated progress
                print(
                    f"\r  {self.object_name}: {progress:3d}% ({mb_written:.1f}/{mb_total:.1f} MB) @ {speed_mbps:.1f} MB/s",
                    end="",
                    flush=True,
                )
                self.last_progress = progress

    def __getattr__(self, name):
        """Delegate other attributes to the wrapped file."""
        return getattr(self.file, name)

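# A minimal standalone sketch of the wrapper (hypothetical names and sizes;
# the real call site is in main() below):
#
#   with open("out.bin", "wb") as f:
#       wrapped = ProgressFileWrapper(f, total_size=4 * 1024 * 1024, object_name="demo")
#       wrapped.write(b"\x00" * 4 * 1024 * 1024)  # prints "demo: 100% ..." once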

def show_usage():
    message = """
Usage: obj_fetch_perf [options]

options:
  -n COUNT           Number of fetches to perform (default: 10)
  -b BUCKET          Object store bucket name (default: rethink)
  -o OBJECT          Object name to fetch (default: RethinkConnectivity.mp4)
  -t TIMEOUT         Timeout per fetch in seconds (default: 30)
  -f FILE            Write to file (streaming mode, memory efficient)
  --overwrite        Overwrite output file if it exists
  --servers SERVERS  NATS server URLs (default: nats://demo.nats.io:4222)
    """
    print(message)


def show_usage_and_die():
    show_usage()
    sys.exit(1)


async def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("-n", "--count", default=DEFAULT_NUM_FETCHES, type=int)
    parser.add_argument("-b", "--bucket", default=DEFAULT_BUCKET)
    parser.add_argument("-o", "--object", default=DEFAULT_OBJECT)
    parser.add_argument("-t", "--timeout", default=DEFAULT_TIMEOUT, type=int)
    parser.add_argument("-f", "--file", help="Write to file (streaming mode)")
    parser.add_argument("--overwrite", action="store_true", help="Overwrite output file if it exists")
    parser.add_argument("--servers", default=[], action="append")
    args = parser.parse_args()

    servers = args.servers
    if not servers:
        servers = ["nats://demo.nats.io:4222"]

    print(f"Connecting to NATS servers: {servers}")

    # Connect to NATS with JetStream
    try:
        nc = await nats.connect(servers, pending_size=1024 * 1024)
        js = nc.jetstream()
    except Exception as e:
        sys.stderr.write(f"ERROR: Failed to connect to NATS: {e}\n")
        show_usage_and_die()

    # Get object store
    try:
        obs = await js.object_store(bucket=args.bucket)
        print(f"Connected to object store bucket: {args.bucket}")
    except Exception as e:
        sys.stderr.write(f"ERROR: Failed to access object store bucket '{args.bucket}': {e}\n")
        await nc.close()
        sys.exit(1)

    # Get object info first to verify it exists and show stats
    try:
        info = await obs.get_info(args.object)
        size_mb = info.size / (1024 * 1024)
        print(f"Object: {args.object}")
        print(f"Size: {info.size} bytes ({size_mb:.2f} MB)")
        print(f"Chunks: {info.chunks}")
        print(f"Description: {info.description}")
        print()
    except Exception as e:
        sys.stderr.write(f"ERROR: Failed to get object info for '{args.object}': {e}\n")
        await nc.close()
        sys.exit(1)

    # Handle file output setup
    if args.file:
        if os.path.exists(args.file) and not args.overwrite:
            sys.stderr.write(f"ERROR: File '{args.file}' already exists. Use --overwrite to replace it.\n")
            await nc.close()
            sys.exit(1)

        # For multiple fetches with file output, append a counter
        if args.count > 1:
            base, ext = os.path.splitext(args.file)
            print(f"Multiple fetches with file output - files will be named: {base}_1{ext}, {base}_2{ext}, etc.")
        else:
            print(f"Streaming output to file: {args.file}")
        print()

    # Start the benchmark
    print(f"Starting benchmark: fetching '{args.object}' {args.count} times")
    if args.file:
        print("Progress (streaming to file):")
    else:
        print("Progress: ", end="", flush=True)

    start = time.time()
    total_bytes = 0
    successful_fetches = 0
    failed_fetches = 0

    for i in range(args.count):
        try:
            # Determine output file for this fetch
            current_file = None
            if args.file:
                if args.count > 1:
                    base, ext = os.path.splitext(args.file)
                    current_file = f"{base}_{i + 1}{ext}"
                else:
                    current_file = args.file

            # Fetch the object
            if current_file:
                # Stream to file with progress tracking
                with open(current_file, "wb") as f:
                    # Wrap the file so writes update the progress display
                    progress_wrapper = ProgressFileWrapper(f, info.size, args.object)
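                    # writeinto= hands each received chunk straight to the
                    # wrapper, so the object is never held in memory as a
                    # whole (nats-py's ObjectStore.get accepts a writable
                    # file-like object here)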
                    await asyncio.wait_for(
                        obs.get(args.object, writeinto=progress_wrapper), timeout=args.timeout
                    )
                # Get file size for stats (the file is closed here, so the
                # size on disk is final rather than partially buffered)
                fetch_bytes = os.path.getsize(current_file)
                # Ensure we show 100% completion
                if progress_wrapper.bytes_written > 0:
                    print(
                        f"\r  📥 {args.object}: 100% ({fetch_bytes / (1024 * 1024):.1f}/{info.size / (1024 * 1024):.1f} MB) ✓"
                    )
            else:
                # Load into memory
                result = await asyncio.wait_for(obs.get(args.object), timeout=args.timeout)
                fetch_bytes = len(result.data)

            total_bytes += fetch_bytes
            successful_fetches += 1

            # Show simple progress for in-memory mode
            if not current_file:
                print("#", end="", flush=True)

        except asyncio.TimeoutError:
            failed_fetches += 1
            if args.file:
                print(f"\r  ❌ {args.object}: Timeout after {args.timeout}s")
            else:
                print("T", end="", flush=True)  # T for timeout
        except Exception as e:
            failed_fetches += 1
            if args.file:
                print(f"\r  ❌ {args.object}: Error - {str(e)[:50]}")
            else:
                print("E", end="", flush=True)  # E for error
            if i == 0:  # Show first error for debugging
                sys.stderr.write(f"\nFirst fetch error: {e}\n")

        # Small pause between fetches
        await asyncio.sleep(0.01)

    elapsed = time.time() - start

    print("\n\nBenchmark Results:")
    print("=================")
    if args.file:
        print("Mode: Streaming to file(s) (memory efficient)")
    else:
        print("Mode: In-memory loading")
    print(f"Total time: {elapsed:.2f} seconds")
    print(f"Successful fetches: {successful_fetches}/{args.count}")
    print(f"Failed fetches: {failed_fetches}")

    if successful_fetches > 0:
        # elapsed covers the whole loop, including failed fetches and the
        # 10 ms pauses, so these are approximate per-fetch rates
        avg_time = elapsed / successful_fetches
        mbytes_per_sec = (total_bytes / elapsed) / (1024 * 1024)
        fetches_per_sec = successful_fetches / elapsed

        print(f"Average fetch time: {avg_time:.3f} seconds")
        print(f"Fetches per second: {fetches_per_sec:.2f}")
        print(f"Throughput: {mbytes_per_sec:.2f} MB/sec")
        print(f"Total data transferred: {total_bytes / (1024 * 1024):.2f} MB")

    await nc.close()

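# Example invocations (hypothetical bucket and object names; point --servers at
# your own deployment, otherwise the public demo server is used):
#
#   python obj_fetch_perf.py -b media -o clip.mp4 -n 5 --servers nats://localhost:4222
#   python obj_fetch_perf.py -b media -o clip.mp4 -f out.mp4 --overwrite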

if __name__ == "__main__":
    asyncio.run(main())