diff --git a/examples/clay_embeddings_example.py b/examples/clay_embeddings_example.py
new file mode 100644
index 00000000..1a3c6bc5
--- /dev/null
+++ b/examples/clay_embeddings_example.py
@@ -0,0 +1,234 @@
+"""
+Example script demonstrating Clay foundation model embeddings with segment-geospatial.
+
+This script shows how to:
+1. Load a geospatial image
+2. Generate Clay foundation model embeddings
+3. Save and load embeddings
+4. Visualize embedding results
+
+Requirements:
+- Clay model checkpoint file
+- Geospatial imagery (GeoTIFF, etc.)
+- Clay model dependencies: claymodel, torch, torchvision, pyyaml, python-box
+"""
+
+import os
+import numpy as np
+import matplotlib.pyplot as plt
+from samgeo import Clay, load_embeddings
+
+
+def main():
+    # Configuration
+    CHECKPOINT_PATH = "path/to/clay-model-checkpoint.ckpt"  # Update this path
+    IMAGE_PATH = "path/to/your/satellite_image.tif"  # Update this path
+    OUTPUT_DIR = "clay_embeddings_output"
+
+    # Create output directory
+    os.makedirs(OUTPUT_DIR, exist_ok=True)
+
+    print("=== Clay Foundation Model Embeddings Example ===\n")
+
+    # Step 1: Initialize Clay embeddings model
+    print("1. Initializing Clay model...")
+    try:
+        clay = Clay(
+            checkpoint_path=CHECKPOINT_PATH,
+            device="auto",  # Will use GPU if available
+        )
+        # Note: mask_ratio=0.0 and shuffle=False are set internally for inference.
+        print("   ✓ Clay model loaded successfully")
+    except Exception as e:
+        print(f"   ✗ Error loading Clay model: {e}")
+        print("   Please ensure you have:")
+        print("   - Valid Clay checkpoint file")
+        print(
+            "   - Clay dependencies: pip install claymodel torch torchvision pyyaml python-box"
+        )
+        return
+
+    # Step 2: Load and analyze image
+    print("\n2. Loading geospatial image...")
+    try:
+        clay.set_image(
+            source=IMAGE_PATH,
+            # sensor_type="sentinel-2-l2a",  # Optional: override auto-detection
+            # date="2023-06-01",  # Optional: specify acquisition date
+            # gsd_override=10.0,  # Optional: override ground sample distance
+        )
+        print("   ✓ Image loaded and analyzed")
+        print(f"   - Image shape: {clay.image.shape}")
+        print(f"   - Detected sensor: {clay.sensor_type}")
+        print(f"   - Center coordinates: ({clay.lat:.4f}, {clay.lon:.4f})")
+    except Exception as e:
+        print(f"   ✗ Error loading image: {e}")
+        print("   Please check the image path and format")
+        return
+
+    # Step 3: Generate embeddings
+    print("\n3. Generating Clay embeddings...")
+    try:
+        # For large images, process in tiles
+        embeddings_result = clay.generate_embeddings(
+            tile_size=256,  # Size of processing tiles
+            overlap=0.1,  # 10% overlap between tiles
+        )
+
+        print("   ✓ Embeddings generated successfully")
+        print(f"   - Number of tiles: {embeddings_result['num_tiles']}")
+        print(f"   - Embedding shape: {embeddings_result['embeddings'].shape}")
+        print(f"   - Feature dimension: {embeddings_result['embeddings'].shape[-1]}")
+
+    except Exception as e:
+        print(f"   ✗ Error generating embeddings: {e}")
+        return
+
+    # Step 4: Save embeddings
+    print("\n4. Saving embeddings...")
+    try:
+        embeddings_file = os.path.join(OUTPUT_DIR, "clay_embeddings.npz")
+        clay.save_embeddings(embeddings_result, embeddings_file, format="npz")
+        print(f"   ✓ Embeddings saved to {embeddings_file}")
+    except Exception as e:
+        print(f"   ✗ Error saving embeddings: {e}")
+        return
+
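+    # Aside (a minimal sketch, not one of the numbered steps): for
+    # similarity-based workflows it is common to L2-normalize embeddings so
+    # that dot products equal cosine similarity. This assumes the
+    # (num_tiles, dim) layout produced above.
+    norms = np.linalg.norm(embeddings_result["embeddings"], axis=-1, keepdims=True)
+    normalized_embeddings = embeddings_result["embeddings"] / np.clip(norms, 1e-12, None)
+    print(f"   - Normalized embeddings shape: {normalized_embeddings.shape}")
+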
+    # Step 5: Load and verify embeddings
+    print("\n5. Loading and verifying saved embeddings...")
+    try:
+        loaded_embeddings = load_embeddings(embeddings_file)
+        print("   ✓ Embeddings loaded successfully")
+        print(f"   - Sensor type: {loaded_embeddings['sensor_type']}")
+        print(f"   - Number of tiles: {loaded_embeddings['num_tiles']}")
+        print(f"   - Original image shape: {loaded_embeddings['image_shape']}")
+    except Exception as e:
+        print(f"   ✗ Error loading embeddings: {e}")
+        return
+
+    # Step 6: Visualize results
+    print("\n6. Creating visualizations...")
+    try:
+        # Plot RGB image if available
+        fig, axes = plt.subplots(1, 2, figsize=(15, 6))
+
+        # Original image (RGB bands if available)
+        image = clay.image
+        if clay.sensor_type in clay.metadata:
+            rgb_indices = clay.metadata[clay.sensor_type].get("rgb_indices", [0, 1, 2])
+            if len(rgb_indices) == 3 and image.shape[2] >= max(rgb_indices) + 1:
+                rgb_image = image[:, :, rgb_indices]
+                # Normalize for display
+                rgb_image = np.clip(rgb_image / np.percentile(rgb_image, 98), 0, 1)
+                axes[0].imshow(rgb_image)
+                axes[0].set_title(f"Original Image ({clay.sensor_type})")
+                axes[0].axis("off")
+            else:
+                axes[0].imshow(image[:, :, 0], cmap="gray")
+                axes[0].set_title("Original Image (First Band)")
+                axes[0].axis("off")
+        else:
+            axes[0].imshow(image[:, :, 0], cmap="gray")
+            axes[0].set_title("Original Image (First Band)")
+            axes[0].axis("off")
+
+        # Embedding visualization (raw feature values of the first tile)
+        embeddings = embeddings_result["embeddings"]
+        if embeddings.shape[0] > 0:
+            # Use the first embedding for visualization
+            first_embedding = embeddings[0].flatten()
+
+            # Create a simple visualization of embedding values
+            embedding_2d = first_embedding[:256].reshape(16, 16)  # First 256 values
+            axes[1].imshow(embedding_2d, cmap="viridis")
+            axes[1].set_title(
+                "Clay Embedding Visualization\n(First 256 features, first tile)"
+            )
+            axes[1].axis("off")
+
+        plt.tight_layout()
+
+        # Save plot
+        plot_file = os.path.join(OUTPUT_DIR, "clay_embeddings_visualization.png")
+        plt.savefig(plot_file, dpi=150, bbox_inches="tight")
+        plt.show()
+
+        print(f"   ✓ Visualization saved to {plot_file}")
+
+    except Exception as e:
+        print(f"   ✗ Error creating visualizations: {e}")
+
+    # Step 7: Demonstrate embedding analysis
+    print("\n7. Embedding analysis...")
+    try:
+        embeddings = embeddings_result["embeddings"]
+
+        # Basic statistics
+        print("   - Embedding statistics:")
+        print(f"     * Mean: {np.mean(embeddings):.4f}")
+        print(f"     * Std: {np.std(embeddings):.4f}")
+        print(f"     * Min: {np.min(embeddings):.4f}")
+        print(f"     * Max: {np.max(embeddings):.4f}")
+
+        # Similarity between tiles (if multiple tiles)
+        if embeddings.shape[0] > 1:
+            from sklearn.metrics.pairwise import cosine_similarity
+
+            similarities = cosine_similarity(embeddings)
+            avg_similarity = np.mean(
+                similarities[np.triu_indices_from(similarities, k=1)]
+            )
+            print(f"     * Average tile similarity: {avg_similarity:.4f}")
+
+        print("   ✓ Analysis complete")
+
+    except Exception as e:
+        print(f"   ✗ Error in embedding analysis: {e}")
+
+    print("\n=== Example completed successfully! ===")
+    print(f"Output files saved in: {OUTPUT_DIR}/")
+    print("\nNext steps:")
+    print("- Use embeddings for similarity search")
+    print("- Fine-tune on downstream tasks")
+    print("- Integrate with SAM for enhanced segmentation")
+
+
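+# A minimal sketch (not part of the original example): nearest-tile search over
+# embeddings saved by main(), using plain numpy cosine similarity. The default
+# file path assumes main()'s OUTPUT_DIR and is a placeholder.
+def example_similarity_search(embeddings_file="clay_embeddings_output/clay_embeddings.npz"):
+    """Illustrative sketch: rank tiles by similarity to the first tile."""
+    data = load_embeddings(embeddings_file)
+    embeddings = data["embeddings"]  # shape: (num_tiles, dim)
+
+    # Cosine similarity of every tile against the first (query) tile
+    query = embeddings[0]
+    norms = np.linalg.norm(embeddings, axis=1) * np.linalg.norm(query)
+    similarities = embeddings @ query / np.clip(norms, 1e-12, None)
+
+    # Rank tiles by similarity; the query itself ranks first
+    ranked = np.argsort(similarities)[::-1]
+    for idx in ranked[:5]:
+        print(f"Tile {idx} at {data['tile_coords'][idx]}: similarity={similarities[idx]:.4f}")
+
+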
===") + print(f"Output files saved in: {OUTPUT_DIR}/") + print("\nNext steps:") + print("- Use embeddings for similarity search") + print("- Fine-tune on downstream tasks") + print("- Integrate with SAM for enhanced segmentation") + + +def example_with_numpy_array(): + """Example showing how to use Clay embeddings with numpy arrays.""" + print("\n=== Numpy Array Example ===") + + # Create a synthetic 4-band image (RGBI) + synthetic_image = np.random.randint(0, 255, (256, 256, 4), dtype=np.uint8) + + try: + # Initialize Clay model + clay = ClayEmbeddings( + checkpoint_path="path/to/clay-model-checkpoint.ckpt", device="auto" + ) + + # Set synthetic image + clay.set_image( + source=synthetic_image, + sensor_type="naip", # Specify sensor type for numpy arrays + date="2023-06-01", + ) + + # Generate embeddings + result = clay.generate_embeddings(tile_size=256) + + print(f"Generated embeddings for synthetic image:") + print(f"- Shape: {result['embeddings'].shape}") + print(f"- Sensor: {result['sensor_type']}") + + except Exception as e: + print(f"Error in numpy array example: {e}") + + +if __name__ == "__main__": + main() + + # Uncomment to run numpy array example + # example_with_numpy_array() diff --git a/samgeo/__init__.py b/samgeo/__init__.py index ffff87e1..9090796a 100644 --- a/samgeo/__init__.py +++ b/samgeo/__init__.py @@ -8,3 +8,4 @@ from .samgeo import * from .samgeo2 import * from .common import show_image_gui +from .clay import Clay, load_embeddings diff --git a/samgeo/clay.py b/samgeo/clay.py new file mode 100644 index 00000000..c07ca4ef --- /dev/null +++ b/samgeo/clay.py @@ -0,0 +1,620 @@ +""" +Clay foundation model wrapper for geospatial embeddings. + +This module provides a wrapper around the Clay foundation model for generating +rich spectral embeddings from geospatial imagery. It integrates with the +segment-geospatial library's raster I/O infrastructure. 
+""" + +import os +import math +import datetime +import numpy as np +import torch +import cv2 +import rasterio +import warnings +from typing import Optional, Union, Tuple, Dict, List, Any +from pathlib import Path + +try: + from claymodel.model import ClayMAEModule + from claymodel.utils import posemb_sincos_2d_with_gsd + from torchvision.transforms import v2 + import yaml + from box import Box + + CLAY_AVAILABLE = True +except ImportError: + CLAY_AVAILABLE = False + +from .common import ( + check_file_path, + download_file, + transform_coords, + reproject, +) + + +# Default metadata for common sensors +DEFAULT_METADATA = { + "sentinel-2-l2a": { + "band_order": [ + "blue", + "green", + "red", + "rededge1", + "rededge2", + "rededge3", + "nir", + "nir08", + "swir16", + "swir22", + ], + "rgb_indices": [2, 1, 0], + "gsd": 10, + "bands": { + "mean": { + "blue": 1105.0, + "green": 1355.0, + "red": 1552.0, + "rededge1": 1887.0, + "rededge2": 2422.0, + "rededge3": 2630.0, + "nir": 2743.0, + "nir08": 2785.0, + "swir16": 2388.0, + "swir22": 1835.0, + }, + "std": { + "blue": 1809.0, + "green": 1757.0, + "red": 1888.0, + "rededge1": 1870.0, + "rededge2": 1732.0, + "rededge3": 1697.0, + "nir": 1742.0, + "nir08": 1648.0, + "swir16": 1470.0, + "swir22": 1379.0, + }, + "wavelength": { + "blue": 0.493, + "green": 0.56, + "red": 0.665, + "rededge1": 0.704, + "rededge2": 0.74, + "rededge3": 0.783, + "nir": 0.842, + "nir08": 0.865, + "swir16": 1.61, + "swir22": 2.19, + }, + }, + }, + "landsat-c2l2-sr": { + "band_order": ["red", "green", "blue", "nir08", "swir16", "swir22"], + "rgb_indices": [0, 1, 2], + "gsd": 30, + "bands": { + "mean": { + "red": 13705.0, + "green": 13310.0, + "blue": 12474.0, + "nir08": 17801.0, + "swir16": 14615.0, + "swir22": 12701.0, + }, + "std": { + "red": 9578.0, + "green": 9408.0, + "blue": 10144.0, + "nir08": 8277.0, + "swir16": 5300.0, + "swir22": 4522.0, + }, + "wavelength": { + "red": 0.65, + "green": 0.56, + "blue": 0.48, + "nir08": 0.86, + "swir16": 1.6, + "swir22": 2.2, + }, + }, + }, + "naip": { + "band_order": ["red", "green", "blue", "nir"], + "rgb_indices": [0, 1, 2], + "gsd": 1.0, + "bands": { + "mean": {"red": 110.16, "green": 115.41, "blue": 98.15, "nir": 139.04}, + "std": {"red": 47.23, "green": 39.82, "blue": 35.43, "nir": 49.86}, + "wavelength": {"red": 0.65, "green": 0.56, "blue": 0.48, "nir": 0.842}, + }, + }, +} + + +def normalize_timestamp(date): + """Normaize the timestamp for clay. Taken from https://github.com/Clay-foundation/stacchip/blob/main/stacchip/processors/prechip.py""" + week = date.isocalendar().week * 2 * np.pi / 52 + hour = date.hour * 2 * np.pi / 24 + + return (math.sin(week), math.cos(week)), (math.sin(hour), math.cos(hour)) + + +def normalize_latlon(bounds): + """Normalize latitude and longitude for clay. Taken from https://github.com/Clay-foundation/stacchip/blob/main/stacchip/processors/prechip.py""" + lon = bounds[0] + (bounds[2] - bounds[0]) / 2 + lat = bounds[1] + (bounds[3] - bounds[1]) / 2 + + lat = lat * np.pi / 180 + lon = lon * np.pi / 180 + + return (math.sin(lat), math.cos(lat)), (math.sin(lon), math.cos(lon)) + + +class Clay: + """ + Clay foundation model wrapper for generating geospatial embeddings. + + This class provides an interface to generate rich spectral embeddings from + geospatial imagery using the Clay foundation model. + """ + + def __init__( + self, + checkpoint_path: str, + model_size: str = "large", + metadata_path: Optional[str] = None, + device: str = "auto", + ): + """ + Initialize Clay embeddings model. 
+
+
+class Clay:
+    """
+    Clay foundation model wrapper for generating geospatial embeddings.
+
+    This class provides an interface to generate rich spectral embeddings from
+    geospatial imagery using the Clay foundation model.
+    """
+
+    def __init__(
+        self,
+        checkpoint_path: str,
+        model_size: str = "large",
+        metadata_path: Optional[str] = None,
+        device: str = "auto",
+    ):
+        """
+        Initialize the Clay embeddings model.
+
+        Args:
+            checkpoint_path: Path to the Clay model checkpoint.
+            model_size: Clay model size ('tiny', 'small', 'base', or 'large').
+            metadata_path: Path to a Clay metadata YAML file (optional).
+            device: Device to run the model on ('auto', 'cuda', 'cpu').
+        """
+        if not CLAY_AVAILABLE:
+            raise ImportError(
+                "Clay model dependencies not available. "
+                "Please install: pip install claymodel torch torchvision pyyaml python-box"
+            )
+
+        self.checkpoint_path = check_file_path(checkpoint_path, make_dirs=False)
+        if not os.path.exists(self.checkpoint_path):
+            raise FileNotFoundError(f"Checkpoint not found: {self.checkpoint_path}")
+
+        # Set device
+        if device == "auto":
+            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        else:
+            self.device = torch.device(device)
+
+        # Load metadata
+        if metadata_path and os.path.exists(metadata_path):
+            with open(metadata_path, "r") as f:
+                self.metadata = Box(yaml.safe_load(f))
+        else:
+            self.metadata = Box(DEFAULT_METADATA)
+            if metadata_path:
+                warnings.warn(
+                    f"Metadata file not found: {metadata_path}. Using defaults."
+                )
+
+        self.model_size = model_size
+        if self.model_size not in ["tiny", "small", "base", "large"]:
+            raise ValueError(
+                "model_size must be one of: 'tiny', 'small', 'base', 'large'"
+            )
+
+        # Load model
+        self._load_model()
+
+        # Image processing attributes
+        self.image = None
+        self.source = None
+        self.sensor_type = None
+        self.raster_profile = None
+
+    def _load_model(self):
+        """Load the Clay model from the checkpoint."""
+        try:
+            torch.set_default_device(self.device)
+            self.module = ClayMAEModule.load_from_checkpoint(
+                checkpoint_path=self.checkpoint_path,
+                model_size=self.model_size,
+                dolls=[16, 32, 64, 128, 256, 768, 1024],
+                doll_weights=[1, 1, 1, 1, 1, 1, 1],
+                mask_ratio=0.0,  # No masking for inference
+                shuffle=False,
+            )
+            self.module.eval()
+        except Exception as e:
+            raise RuntimeError(f"Failed to load Clay model: {e}")
+
+    def _detect_sensor_type(
+        self, src: rasterio.DatasetReader, source_path: Optional[str] = None
+    ) -> str:
+        """
+        Detect the sensor type from raster metadata and characteristics.
+
+        Args:
+            src: Rasterio dataset reader
+            source_path: Optional source file path for filename-based detection
+
+        Returns:
+            Detected sensor type string
+        """
+        band_count = src.count
+        resolution = abs(src.transform[0])  # Pixel size in CRS units
+
+        # Try filename-based detection first
+        if source_path:
+            filename = os.path.basename(source_path).lower()
+            if "sentinel" in filename or "s2" in filename:
+                return "sentinel-2-l2a"
+            elif "landsat" in filename or "l8" in filename or "l9" in filename:
+                return "landsat-c2l2-sr"
+            elif "naip" in filename:
+                return "naip"
+
+        # Fall back to resolution and band-count heuristics
+        if band_count == 4 and resolution <= 5:
+            return "naip"  # High-res 4-band imagery
+        elif band_count >= 6 and 25 <= resolution <= 35:
+            return "landsat-c2l2-sr"  # Landsat resolution
+        elif band_count >= 10 and 8 <= resolution <= 12:
+            return "sentinel-2-l2a"  # Sentinel-2 resolution
+        elif band_count == 4:
+            return "naip"  # Default 4-band to NAIP
+        else:
+            # Default fallback
+            warnings.warn(
+                f"Could not detect sensor type (bands: {band_count}, "
+                f"resolution: {resolution:.1f}m). Defaulting to NAIP."
+            )
+            return "naip"
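+
+    # For example, a 4-band GeoTIFF at 0.6 m resolution falls through to the
+    # "naip" branch, while a 10-band raster at 10 m resolution is detected as
+    # "sentinel-2-l2a". Note the heuristics assume projected CRS units of
+    # meters; pass sensor_type explicitly to set_image() when in doubt.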
+ ) + return "naip" + + def _get_raster_center_latlon( + self, src: rasterio.DatasetReader + ) -> Tuple[float, float]: + """Get the center lat/lon of the raster.""" + bounds = src.bounds + center_x = (bounds.left + bounds.right) / 2 + center_y = (bounds.bottom + bounds.top) / 2 + + # Transform to WGS84 if needed + if src.crs != "EPSG:4326": + lon, lat = transform_coords([(center_x, center_y)], src.crs, "EPSG:4326")[0] + else: + lon, lat = center_x, center_y + + return lat, lon + + def _prepare_datacube( + self, + image: np.ndarray, + sensor_type: str, + lat: float, + lon: float, + date: Optional[datetime.datetime] = None, + gsd_override: Optional[float] = None, + ) -> Dict[str, torch.Tensor]: + """ + Prepare datacube for Clay model input. + + Args: + image: Input image array [H, W, C] + sensor_type: Detected sensor type + lat: Latitude of image center + lon: Longitude of image center + date: Image acquisition date + gsd_override: Override GSD value + + Returns: + Datacube dictionary for Clay model + """ + if date is None: + date = datetime.datetime.now() + + # Get sensor metadata + sensor_meta = self.metadata[sensor_type] + band_order = sensor_meta.band_order + gsd = gsd_override if gsd_override is not None else sensor_meta.gsd + + # Extract normalization parameters + means = [sensor_meta.bands.mean[band] for band in band_order] + stds = [sensor_meta.bands.std[band] for band in band_order] + wavelengths = [sensor_meta.bands.wavelength[band] for band in band_order] + + # Convert image to torch tensor and normalize + # Ensure we have the right number of bands + if image.shape[2] != len(band_order): + warnings.warn( + f"Image has {image.shape[2]} bands but sensor {sensor_type} " + f"expects {len(band_order)} bands. Using available bands." + ) + # Take only the available bands + num_bands = min(image.shape[2], len(band_order)) + image = image[:, :, :num_bands] + means = means[:num_bands] + stds = stds[:num_bands] + wavelengths = wavelengths[:num_bands] + + # Convert to tensor and transpose to [C, H, W] + pixels = torch.from_numpy(image.astype(np.float32)).permute(2, 0, 1) + + # Normalize + transform = v2.Compose([v2.Normalize(mean=means, std=stds)]) + pixels = transform(pixels).unsqueeze(0) # Add batch dimension + + # Prepare temporal encoding + time_norm = normalize_timestamp(date) + + # Prepare spatial encoding + lat_norm, lon_norm = normalize_latlon(lat, lon) + + # Create datacube + datacube = { + "pixels": pixels.to(self.device), + "time": torch.tensor( + time_norm + + time_norm, # Clay expects 4 elements: [week, hour, week, hour] + dtype=torch.float32, + device=self.device, + ).unsqueeze(0), + "latlon": torch.tensor( + lat_norm + + lon_norm, # Clay expects 4 elements: [sin_lat, cos_lat, sin_lon, cos_lon] + dtype=torch.float32, + device=self.device, + ).unsqueeze(0), + "gsd": torch.tensor(gsd, device=self.device), + "waves": torch.tensor(wavelengths, device=self.device), + } + + return datacube + + def set_image( + self, + source: Union[str, np.ndarray], + sensor_type: Optional[str] = None, + date: Optional[Union[str, datetime.datetime]] = None, + gsd_override: Optional[float] = None, + ): + """ + Set the input image for embedding generation. 
+
+    def set_image(
+        self,
+        source: Union[str, np.ndarray],
+        sensor_type: Optional[str] = None,
+        date: Optional[Union[str, datetime.datetime]] = None,
+        gsd_override: Optional[float] = None,
+    ):
+        """
+        Set the input image for embedding generation.
+
+        Args:
+            source: Path to image file or numpy array
+            sensor_type: Optional sensor type override
+            date: Image acquisition date
+            gsd_override: Override GSD value
+        """
+        if isinstance(source, str):
+            if source.startswith("http"):
+                source = download_file(source)
+
+            if not os.path.exists(source):
+                raise ValueError(f"Input path {source} does not exist.")
+
+            # Read with rasterio for geospatial images
+            try:
+                with rasterio.open(source) as src:
+                    # Read all bands
+                    image = src.read()  # Shape: [C, H, W]
+                    image = np.transpose(image, (1, 2, 0))  # Convert to [H, W, C]
+
+                    # Store raster metadata
+                    self.raster_profile = src.profile
+
+                    # Detect sensor type
+                    if sensor_type is None:
+                        sensor_type = self._detect_sensor_type(src, source)
+
+                    # Get image center coordinates
+                    lat, lon = self._get_raster_center_latlon(src)
+
+            except Exception:
+                # Fall back to OpenCV for regular images
+                image = cv2.imread(source)
+                if image is None:
+                    raise ValueError(f"Could not read image: {source}")
+                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+                # Use defaults for non-geospatial images
+                sensor_type = sensor_type or "naip"
+                lat, lon = 0.0, 0.0  # Default coordinates
+                self.raster_profile = None
+
+        elif isinstance(source, np.ndarray):
+            image = source
+            sensor_type = sensor_type or "naip"
+            lat, lon = 0.0, 0.0
+            self.raster_profile = None
+
+        else:
+            raise ValueError("Source must be a file path or numpy array")
+
+        # Parse date if string
+        if isinstance(date, str):
+            try:
+                date = datetime.datetime.fromisoformat(date.replace("Z", "+00:00"))
+            except ValueError:
+                warnings.warn(f"Could not parse date: {date}. Using current time.")
+                date = datetime.datetime.now()
+        elif date is None:
+            date = datetime.datetime.now()
+
+        # Store image and metadata
+        self.source = source if isinstance(source, str) else None
+        self.image = image
+        self.sensor_type = sensor_type
+        self.lat = lat
+        self.lon = lon
+        self.date = date
+        self.gsd_override = gsd_override
+
+        print(
+            f"Set image: shape={image.shape}, sensor={sensor_type}, "
+            f"lat={lat:.4f}, lon={lon:.4f}"
+        )
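+
+    # Tiling arithmetic used below: with tile_size=256 and overlap=0.1 the
+    # step is int(256 * 0.9) = 230 px, so a 1024x1024 image yields
+    # floor((1024 - 256) / 230) + 1 = 4 tile positions per axis (16 tiles).
+    # Pixels beyond the last full tile on each axis are not embedded.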
+
+    def generate_embeddings(
+        self, tile_size: int = 256, overlap: float = 0.0
+    ) -> Dict[str, Any]:
+        """
+        Generate embeddings for the loaded image.
+
+        Args:
+            tile_size: Size of tiles for processing large images
+            overlap: Overlap fraction between tiles (0.0 to 1.0)
+
+        Returns:
+            Dictionary containing embeddings and metadata
+        """
+        if self.image is None:
+            raise ValueError("No image loaded. Call set_image() first.")
+
+        image = self.image
+        h, w = image.shape[:2]
+
+        # If the image is smaller than tile_size, process it as a single tile
+        if h <= tile_size and w <= tile_size:
+            # Pad image to tile_size if needed
+            if h < tile_size or w < tile_size:
+                pad_h = max(0, tile_size - h)
+                pad_w = max(0, tile_size - w)
+                image = np.pad(image, ((0, pad_h), (0, pad_w), (0, 0)), mode="reflect")
+
+            # Generate a single embedding
+            datacube = self._prepare_datacube(
+                image,
+                self.sensor_type,
+                self.lat,
+                self.lon,
+                self.date,
+                self.gsd_override,
+            )
+
+            with torch.no_grad():
+                encoded_patches, _, _, _ = self.module.model.encoder(datacube)
+                # Extract the class token (global embedding)
+                embedding = encoded_patches[:, 0, :].cpu().numpy()
+
+            return {
+                "embeddings": embedding,
+                "tile_coords": [(0, 0, h, w)],
+                "image_shape": (h, w),
+                "sensor_type": self.sensor_type,
+                "lat": self.lat,
+                "lon": self.lon,
+                "date": self.date.isoformat() if self.date else None,
+                "num_tiles": 1,
+            }
+
+        else:
+            # Process as overlapping tiles; tiles that would extend past the
+            # image edge are skipped rather than padded
+            step_size = int(tile_size * (1 - overlap))
+            embeddings = []
+            tile_coords = []
+
+            for y in range(0, h - tile_size + 1, step_size):
+                for x in range(0, w - tile_size + 1, step_size):
+                    # Extract tile
+                    tile = image[y : y + tile_size, x : x + tile_size]
+
+                    # Prepare the datacube for this tile
+                    datacube = self._prepare_datacube(
+                        tile,
+                        self.sensor_type,
+                        self.lat,
+                        self.lon,
+                        self.date,
+                        self.gsd_override,
+                    )
+
+                    # Generate embedding
+                    with torch.no_grad():
+                        encoded_patches, _, _, _ = self.module.model.encoder(datacube)
+                        embedding = encoded_patches[:, 0, :].cpu().numpy()
+
+                    embeddings.append(embedding)
+                    tile_coords.append((x, y, x + tile_size, y + tile_size))
+
+            return {
+                "embeddings": np.vstack(embeddings),
+                "tile_coords": tile_coords,
+                "image_shape": (h, w),
+                "sensor_type": self.sensor_type,
+                "lat": self.lat,
+                "lon": self.lon,
+                "date": self.date.isoformat() if self.date else None,
+                "num_tiles": len(embeddings),
+            }
+
+    def save_embeddings(
+        self, embeddings_result: Dict[str, Any], output_path: str, format: str = "npz"
+    ):
+        """
+        Save embeddings to a file.
+
+        Args:
+            embeddings_result: Result from generate_embeddings()
+            output_path: Output file path
+            format: Output format ('npz', 'pt')
+        """
+        output_path = check_file_path(output_path)
+
+        if format == "npz":
+            np.savez_compressed(
+                output_path,
+                embeddings=embeddings_result["embeddings"],
+                tile_coords=np.array(embeddings_result["tile_coords"]),
+                image_shape=np.array(embeddings_result["image_shape"]),
+                sensor_type=embeddings_result["sensor_type"],
+                lat=embeddings_result["lat"],
+                lon=embeddings_result["lon"],
+                date=embeddings_result["date"],
+                num_tiles=embeddings_result["num_tiles"],
+            )
+        elif format == "pt":
+            torch.save(embeddings_result, output_path)
+        else:
+            raise ValueError(f"Unsupported format: {format}")
+
+        print(f"Saved embeddings to {output_path}")
+
+
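+# Note: the npz round trip is not perfectly symmetric; tuples in
+# "tile_coords" come back as lists, and "date" comes back as an ISO string,
+# as reflected in the dictionary rebuilt below.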
+def load_embeddings(file_path: str) -> Dict[str, Any]:
+    """
+    Load embeddings from a file.
+
+    Args:
+        file_path: Path to embeddings file
+
+    Returns:
+        Embeddings dictionary
+    """
+    if file_path.endswith(".npz"):
+        data = np.load(file_path, allow_pickle=True)
+        return {
+            "embeddings": data["embeddings"],
+            "tile_coords": data["tile_coords"].tolist(),
+            "image_shape": tuple(data["image_shape"]),
+            "sensor_type": str(data["sensor_type"]),
+            "lat": float(data["lat"]),
+            "lon": float(data["lon"]),
+            "date": str(data["date"]) if data["date"] != "None" else None,
+            "num_tiles": int(data["num_tiles"]),
+        }
+    elif file_path.endswith(".pt"):
+        return torch.load(file_path, map_location="cpu")
+    else:
+        raise ValueError(f"Unsupported file format: {file_path}")
diff --git a/samgeo/clay_metadata.yaml b/samgeo/clay_metadata.yaml
new file mode 100644
index 00000000..d18ebbae
--- /dev/null
+++ b/samgeo/clay_metadata.yaml
@@ -0,0 +1,295 @@
+sentinel-2-l2a:
+  band_order:
+    - blue
+    - green
+    - red
+    - rededge1
+    - rededge2
+    - rededge3
+    - nir
+    - nir08
+    - swir16
+    - swir22
+  rgb_indices:
+    - 2
+    - 1
+    - 0
+  gsd: 10
+  bands:
+    mean:
+      blue: 1105.
+      green: 1355.
+      red: 1552.
+      rededge1: 1887.
+      rededge2: 2422.
+      rededge3: 2630.
+      nir: 2743.
+      nir08: 2785.
+      swir16: 2388.
+      swir22: 1835.
+    std:
+      blue: 1809.
+      green: 1757.
+      red: 1888.
+      rededge1: 1870.
+      rededge2: 1732.
+      rededge3: 1697.
+      nir: 1742.
+      nir08: 1648.
+      swir16: 1470.
+      swir22: 1379.
+    wavelength:
+      blue: 0.493
+      green: 0.56
+      red: 0.665
+      rededge1: 0.704
+      rededge2: 0.74
+      rededge3: 0.783
+      nir: 0.842
+      nir08: 0.865
+      swir16: 1.61
+      swir22: 2.19
+planetscope-sr:
+  band_order:
+    - coastal_blue
+    - blue
+    - green_i
+    - green
+    - yellow
+    - red
+    - rededge
+    - nir
+  rgb_indices:
+    - 5
+    - 3
+    - 1
+  gsd: 5
+  bands:
+    mean:
+      coastal_blue: 1720.
+      blue: 1715.
+      green_i: 1913.
+      green: 2088.
+      yellow: 2274.
+      red: 2290.
+      rededge: 2613.
+      nir: 3970.
+    std:
+      coastal_blue: 747.
+      blue: 698.
+      green_i: 739.
+      green: 768.
+      yellow: 849.
+      red: 868.
+      rededge: 849.
+      nir: 914.
+    wavelength:
+      coastal_blue: 0.443
+      blue: 0.490
+      green_i: 0.531
+      green: 0.565
+      yellow: 0.610
+      red: 0.665
+      rededge: 0.705
+      nir: 0.865
+landsat-c2l1:
+  band_order:
+    - red
+    - green
+    - blue
+    - nir08
+    - swir16
+    - swir22
+  rgb_indices:
+    - 0
+    - 1
+    - 2
+  gsd: 30
+  bands:
+    mean:
+      red: 10678.
+      green: 10563.
+      blue: 11083.
+      nir08: 14792.
+      swir16: 12276.
+      swir22: 10114.
+    std:
+      red: 6025.
+      green: 5411.
+      blue: 5468.
+      nir08: 6746.
+      swir16: 5897.
+      swir22: 4850.
+    wavelength:
+      red: 0.65
+      green: 0.56
+      blue: 0.48
+      nir08: 0.86
+      swir16: 1.6
+      swir22: 2.2
+landsat-c2l2-sr:
+  band_order:
+    - red
+    - green
+    - blue
+    - nir08
+    - swir16
+    - swir22
+  rgb_indices:
+    - 0
+    - 1
+    - 2
+  gsd: 30
+  bands:
+    mean:
+      red: 13705.
+      green: 13310.
+      blue: 12474.
+      nir08: 17801.
+      swir16: 14615.
+      swir22: 12701.
+    std:
+      red: 9578.
+      green: 9408.
+      blue: 10144.
+      nir08: 8277.
+      swir16: 5300.
+      swir22: 4522.
+    wavelength:
+      red: 0.65
+      green: 0.56
+      blue: 0.48
+      nir08: 0.86
+      swir16: 1.6
+      swir22: 2.2
+naip:
+  band_order:
+    - red
+    - green
+    - blue
+    - nir
+  rgb_indices:
+    - 0
+    - 1
+    - 2
+  gsd: 1.0
+  bands:
+    mean:
+      red: 110.16
+      green: 115.41
+      blue: 98.15
+      nir: 139.04
+    std:
+      red: 47.23
+      green: 39.82
+      blue: 35.43
+      nir: 49.86
+    wavelength:
+      red: 0.65
+      green: 0.56
+      blue: 0.48
+      nir: 0.842
+linz:
+  band_order:
+    - red
+    - green
+    - blue
+  rgb_indices:
+    - 0
+    - 1
+    - 2
+  gsd: 0.5
+  bands:
+    mean:
+      red: 89.96
+      green: 99.46
+      blue: 89.51
+    std:
+      red: 41.83
+      green: 36.96
+      blue: 31.45
+    wavelength:
+      red: 0.635
+      green: 0.555
+      blue: 0.465
+sentinel-1-rtc:
+  band_order:
+    - vv
+    - vh
+  gsd: 10
+  bands:
+    mean:
+      vv: -12.113
+      vh: -18.673
+    std:
+      vv: 8.314
+      vh: 8.017
+    wavelength:
+      vv: 3.5
+      vh: 4.0
+modis:
+  band_order:
+    - sur_refl_b01
+    - sur_refl_b02
+    - sur_refl_b03
+    - sur_refl_b04
+    - sur_refl_b05
+    - sur_refl_b06
+    - sur_refl_b07
+  rgb_indices:
+    - 0
+    - 3
+    - 2
+  gsd: 500
+  bands:
+    mean:
+      sur_refl_b01: 1072.
+      sur_refl_b02: 1624.
+      sur_refl_b03: 931.
+      sur_refl_b04: 1023.
+      sur_refl_b05: 1599.
+      sur_refl_b06: 1404.
+      sur_refl_b07: 1051.
+    std:
+      sur_refl_b01: 1643.
+      sur_refl_b02: 1878.
+      sur_refl_b03: 1449.
+      sur_refl_b04: 1538.
+      sur_refl_b05: 1763.
+      sur_refl_b06: 1618.
+      sur_refl_b07: 1396.
+    wavelength:
+      sur_refl_b01: .645
+      sur_refl_b02: .858
+      sur_refl_b03: .469
+      sur_refl_b04: .555
+      sur_refl_b05: 1.240
+      sur_refl_b06: 1.640
+      sur_refl_b07: 2.130
+satellogic-MSI-L1D:
+  band_order:
+    - red
+    - green
+    - blue
+    - nir
+  rgb_indices:
+    - 0
+    - 1
+    - 2
+  gsd: 1.0
+  bands:
+    mean:
+      red: 1451.54
+      green: 1456.54
+      blue: 1543.22
+      nir: 2132.68
+    std:
+      red: 995.48
+      green: 771.29
+      blue: 708.86
+      nir: 1236.71
+    wavelength:
+      red: 0.640
+      green: 0.545
+      blue: 0.480
+      nir: 0.825