From fe4b7ee4e3e58271240185ea0d0fb7ff50edc498 Mon Sep 17 00:00:00 2001 From: hgherzog Date: Thu, 25 Sep 2025 19:50:38 +0000 Subject: [PATCH 01/14] allow unet to resize outputs --- rslearn/models/copernicus_fm.py | 206 ++++++++++++++++++++++++++++++++ rslearn/models/copernicusfm.py | 206 ++++++++++++++++++++++++++++++++ rslearn/models/unet.py | 9 ++ 3 files changed, 421 insertions(+) create mode 100644 rslearn/models/copernicus_fm.py create mode 100644 rslearn/models/copernicusfm.py diff --git a/rslearn/models/copernicus_fm.py b/rslearn/models/copernicus_fm.py new file mode 100644 index 00000000..ac313fed --- /dev/null +++ b/rslearn/models/copernicus_fm.py @@ -0,0 +1,206 @@ +"""Copernicus FM model.""" + +from enum import Enum +import logging +import math + +import torch +import torch.nn.functional as F +from einops import rearrange +from upath import UPath + +from .copernicusfm.src.model_vit import vit_base_patch16 + + +logger = logging.getLogger(__name__) +class CopernicusFMModality(Enum): + """Modality for Copernicus FM.""" + SENTINEL2_L2A = "sentinel2_l2a" + SENTINEL1 = "sentinel1" + +MODALITY_TO_WAVELENGTH_BANDWIDTHS: dict[str, dict[str, list]] = { + # https://github.com/zhu-xlab/Copernicus-FM/blob/main/Copernicus-Bench/src/configs/dataset/cobench_eurosat_s2.yaml + CopernicusFMModality.SENTINEL2_L2A.value: { + "band_names": [ + "B01", + "B02", + "B03", + "B04", + "B05", + "B06", + "B07", + "B08", + "B8A", + "B09", + "B10", + "B11", + "B12", + ], + "band_wavelengths": [ + 440, + 490, + 560, + 665, + 705, + 740, + 783, + 842, + 860, + 940, + 1370, + 1610, + 2190, + ], + "band_bandwidths": [20, 65, 35, 30, 15, 15, 20, 115, 20, 20, 30, 90, 180], + }, + # https://github.com/zhu-xlab/Copernicus-FM/blob/main/Copernicus-Bench/src/configs/dataset/cobench_eurosat_s1.yaml + CopernicusFMModality.SENTINEL1.value: { + "band_names": ["vv", "vh"], + "band_wavelengths": [50000000, 50000000], + "band_bandwidths": [1e9, 1e9], + }, +} + + +class CopernicusFM(torch.nn.Module): + """Wrapper for Copernicus FM to ingest Masked Helios Sample.""" + + image_resolution = 224 + patch_size = 16 + input_mode = "spectral" + supported_modalities = [CopernicusFMModality.SENTINEL2_L2A.value, CopernicusFMModality.SENTINEL1.value] + + def __init__( + self, band_order: dict[str, list[str]], load_directory: str = "/weka/dfive-default/helios/models/copernicusfm" + ) -> None: + """Initialize the Copernicus FM wrapper. + + Args: + load_directory: The directory to load from + """ + super().__init__() + + # global_pool=True so that we initialize the fc_norm layer + self.band_order = band_order + self.model = vit_base_patch16(num_classes=10, global_pool=True) + check_point = torch.load( + UPath(load_directory) / "CopernicusFM_ViT_base_varlang_e100.pth" + ) + if "model" in check_point: + state_dict = check_point["model"] + else: + state_dict = check_point + self.model.load_state_dict(state_dict, strict=False) + + # take MODALITY_TO_WAVELENGTH_BANDWIDTHS and rearrage it so that it has the same + # ordering as the Helios band orders, defined by Modality.band_order + self.modality_to_wavelength_bandwidths = {} + for modality in self.supported_modalities: + wavelength_bandwidths = MODALITY_TO_WAVELENGTH_BANDWIDTHS[modality] + band_order = MODALITY_TO_WAVELENGTH_BANDWIDTHS[modality]["band_names"] + wavelengths = [] + bandwidths = [] + band_order = self.band_order.get(modality, None) + if band_order is None: + logger.warning(f"Band order for modality {modality} not found in band_order dictionary, unable to use this modality unless specified") + continue + for b in band_order: + cfm_idx = wavelength_bandwidths["band_names"].index(b) + wavelengths.append(wavelength_bandwidths["band_wavelengths"][cfm_idx]) + bandwidths.append(wavelength_bandwidths["band_bandwidths"][cfm_idx]) + self.modality_to_wavelength_bandwidths[modality] = { + "band_bandwidths": bandwidths, + "band_wavelengths": wavelengths, + } + + def _resize_data( + self, data: torch.Tensor + ) -> list[torch.Tensor]: + """Process individual modality data. + + Args: + data: Input tensor of shape [B, C, H, W] + + Returns: + list of tensors of shape [B, C, H, W] + """ + # Get original dimensions + original_height = data.shape[2] + new_height = ( + self.patch_size if original_height == 1 else self.image_resolution + ) + data = F.interpolate( + data, + size=(new_height, new_height), + mode="bilinear", + align_corners=False, + ) + return data + + def prepare_input( + self, + inputs: dict[str, torch.Tensor], + ) -> tuple[list[torch.Tensor], list[int], list[int]]: + """Prepare input for the CopernicusFM model from MaskedHeliosSample.""" + wavelengths: list[int] = [] + bandwidths: list[int] = [] + all_processed_data: list[list[torch.Tensor]] = [] + for modality in inputs.keys(): + if modality not in self.supported_modalities: + logger.warning( + f"Skipping modality {modality} as it is not in the supported " + f"modalities list {self.supported_modalities}" + ) + continue + + data = inputs[modality] + + if data is None: + continue + + all_processed_data.append(self._resize_data(data)) + wavelengths.extend( + self.modality_to_wavelength_bandwidths[modality]["band_wavelengths"] + ) + bandwidths.extend( + self.modality_to_wavelength_bandwidths[modality]["band_bandwidths"] + ) + + concatenated_processed_data = torch.cat(all_processed_data, dim=1) + return concatenated_processed_data, wavelengths, bandwidths + + def forward( + self, + inputs: list[dict[str, torch.Tensor]], + ) -> torch.Tensor: + """Forward pass through CopernicusFM model.""" + batch_inputs = {key: torch.stack([inp[key] for inp in inputs], dim=0) for key in inputs[0].keys()} + # Prepare input + data, wavelengths, bandwidths = self.prepare_input( + batch_inputs + ) + meta = torch.full( + (1, 4), float("nan"), device=data.device + ) # [lon, lat, delta_time, patch_token_area], assume unknown + # "The embed tensor contains the encoded image features, which can be used for downstream tasks." + _, timestep_output = self.model( + data, + meta, + wavelengths, + bandwidths, + None, + self.input_mode, + self.patch_size, + ) + # no norm, following + # https://github.com/zhu-xlab/Copernicus-FM/blob/main/Copernicus-Bench/src/foundation_models/CopernicusFM/models_dwv_seg.py + side = math.isqrt(timestep_output.shape[1]) + output_features = rearrange( + timestep_output, "b (h w) c -> b c h w ", h=side, w=side + ) + return [output_features] + + def get_backbone_channels(self) -> list[tuple[int, int]]: + """Returns the output channels of this model when used as a backbone.""" + # TODO: load this from a constant depending on the model size + return [(self.patch_size, 768)] diff --git a/rslearn/models/copernicusfm.py b/rslearn/models/copernicusfm.py new file mode 100644 index 00000000..ac313fed --- /dev/null +++ b/rslearn/models/copernicusfm.py @@ -0,0 +1,206 @@ +"""Copernicus FM model.""" + +from enum import Enum +import logging +import math + +import torch +import torch.nn.functional as F +from einops import rearrange +from upath import UPath + +from .copernicusfm.src.model_vit import vit_base_patch16 + + +logger = logging.getLogger(__name__) +class CopernicusFMModality(Enum): + """Modality for Copernicus FM.""" + SENTINEL2_L2A = "sentinel2_l2a" + SENTINEL1 = "sentinel1" + +MODALITY_TO_WAVELENGTH_BANDWIDTHS: dict[str, dict[str, list]] = { + # https://github.com/zhu-xlab/Copernicus-FM/blob/main/Copernicus-Bench/src/configs/dataset/cobench_eurosat_s2.yaml + CopernicusFMModality.SENTINEL2_L2A.value: { + "band_names": [ + "B01", + "B02", + "B03", + "B04", + "B05", + "B06", + "B07", + "B08", + "B8A", + "B09", + "B10", + "B11", + "B12", + ], + "band_wavelengths": [ + 440, + 490, + 560, + 665, + 705, + 740, + 783, + 842, + 860, + 940, + 1370, + 1610, + 2190, + ], + "band_bandwidths": [20, 65, 35, 30, 15, 15, 20, 115, 20, 20, 30, 90, 180], + }, + # https://github.com/zhu-xlab/Copernicus-FM/blob/main/Copernicus-Bench/src/configs/dataset/cobench_eurosat_s1.yaml + CopernicusFMModality.SENTINEL1.value: { + "band_names": ["vv", "vh"], + "band_wavelengths": [50000000, 50000000], + "band_bandwidths": [1e9, 1e9], + }, +} + + +class CopernicusFM(torch.nn.Module): + """Wrapper for Copernicus FM to ingest Masked Helios Sample.""" + + image_resolution = 224 + patch_size = 16 + input_mode = "spectral" + supported_modalities = [CopernicusFMModality.SENTINEL2_L2A.value, CopernicusFMModality.SENTINEL1.value] + + def __init__( + self, band_order: dict[str, list[str]], load_directory: str = "/weka/dfive-default/helios/models/copernicusfm" + ) -> None: + """Initialize the Copernicus FM wrapper. + + Args: + load_directory: The directory to load from + """ + super().__init__() + + # global_pool=True so that we initialize the fc_norm layer + self.band_order = band_order + self.model = vit_base_patch16(num_classes=10, global_pool=True) + check_point = torch.load( + UPath(load_directory) / "CopernicusFM_ViT_base_varlang_e100.pth" + ) + if "model" in check_point: + state_dict = check_point["model"] + else: + state_dict = check_point + self.model.load_state_dict(state_dict, strict=False) + + # take MODALITY_TO_WAVELENGTH_BANDWIDTHS and rearrage it so that it has the same + # ordering as the Helios band orders, defined by Modality.band_order + self.modality_to_wavelength_bandwidths = {} + for modality in self.supported_modalities: + wavelength_bandwidths = MODALITY_TO_WAVELENGTH_BANDWIDTHS[modality] + band_order = MODALITY_TO_WAVELENGTH_BANDWIDTHS[modality]["band_names"] + wavelengths = [] + bandwidths = [] + band_order = self.band_order.get(modality, None) + if band_order is None: + logger.warning(f"Band order for modality {modality} not found in band_order dictionary, unable to use this modality unless specified") + continue + for b in band_order: + cfm_idx = wavelength_bandwidths["band_names"].index(b) + wavelengths.append(wavelength_bandwidths["band_wavelengths"][cfm_idx]) + bandwidths.append(wavelength_bandwidths["band_bandwidths"][cfm_idx]) + self.modality_to_wavelength_bandwidths[modality] = { + "band_bandwidths": bandwidths, + "band_wavelengths": wavelengths, + } + + def _resize_data( + self, data: torch.Tensor + ) -> list[torch.Tensor]: + """Process individual modality data. + + Args: + data: Input tensor of shape [B, C, H, W] + + Returns: + list of tensors of shape [B, C, H, W] + """ + # Get original dimensions + original_height = data.shape[2] + new_height = ( + self.patch_size if original_height == 1 else self.image_resolution + ) + data = F.interpolate( + data, + size=(new_height, new_height), + mode="bilinear", + align_corners=False, + ) + return data + + def prepare_input( + self, + inputs: dict[str, torch.Tensor], + ) -> tuple[list[torch.Tensor], list[int], list[int]]: + """Prepare input for the CopernicusFM model from MaskedHeliosSample.""" + wavelengths: list[int] = [] + bandwidths: list[int] = [] + all_processed_data: list[list[torch.Tensor]] = [] + for modality in inputs.keys(): + if modality not in self.supported_modalities: + logger.warning( + f"Skipping modality {modality} as it is not in the supported " + f"modalities list {self.supported_modalities}" + ) + continue + + data = inputs[modality] + + if data is None: + continue + + all_processed_data.append(self._resize_data(data)) + wavelengths.extend( + self.modality_to_wavelength_bandwidths[modality]["band_wavelengths"] + ) + bandwidths.extend( + self.modality_to_wavelength_bandwidths[modality]["band_bandwidths"] + ) + + concatenated_processed_data = torch.cat(all_processed_data, dim=1) + return concatenated_processed_data, wavelengths, bandwidths + + def forward( + self, + inputs: list[dict[str, torch.Tensor]], + ) -> torch.Tensor: + """Forward pass through CopernicusFM model.""" + batch_inputs = {key: torch.stack([inp[key] for inp in inputs], dim=0) for key in inputs[0].keys()} + # Prepare input + data, wavelengths, bandwidths = self.prepare_input( + batch_inputs + ) + meta = torch.full( + (1, 4), float("nan"), device=data.device + ) # [lon, lat, delta_time, patch_token_area], assume unknown + # "The embed tensor contains the encoded image features, which can be used for downstream tasks." + _, timestep_output = self.model( + data, + meta, + wavelengths, + bandwidths, + None, + self.input_mode, + self.patch_size, + ) + # no norm, following + # https://github.com/zhu-xlab/Copernicus-FM/blob/main/Copernicus-Bench/src/foundation_models/CopernicusFM/models_dwv_seg.py + side = math.isqrt(timestep_output.shape[1]) + output_features = rearrange( + timestep_output, "b (h w) c -> b c h w ", h=side, w=side + ) + return [output_features] + + def get_backbone_channels(self) -> list[tuple[int, int]]: + """Returns the output channels of this model when used as a backbone.""" + # TODO: load this from a constant depending on the model size + return [(self.patch_size, 768)] diff --git a/rslearn/models/unet.py b/rslearn/models/unet.py index 5d0353ab..6d06d11f 100644 --- a/rslearn/models/unet.py +++ b/rslearn/models/unet.py @@ -20,6 +20,7 @@ def __init__( conv_layers_per_resolution: int = 1, kernel_size: int = 3, num_channels: dict[int, int] = {}, + original_size_to_interpolate: tuple[int, int] | None = None, ) -> None: """Initialize a UNetDecoder. @@ -33,6 +34,7 @@ def __init__( kernel_size: kernel size to use in convolutional layers num_channels: override number of output channels to use at different downsample factors. + original_size_to_interpolate: the original size to interpolate the output to. """ super().__init__() @@ -123,6 +125,11 @@ def __init__( ) layers.append(torch.nn.Sequential(*cur_layers)) self.layers = torch.nn.ModuleList(layers) + self.original_size_to_interpolate = original_size_to_interpolate + + def _resize(self, features: torch.Tensor) -> torch.Tensor: + """Interpolate the features to the original size.""" + return F.interpolate(features, size=self.original_size_to_interpolate, mode="bilinear", align_corners=False) def forward( self, in_features: list[torch.Tensor], inputs: list[dict[str, Any]] @@ -141,4 +148,6 @@ def forward( cur_features = self.layers[0](in_features[0]) for in_feat, layer in zip(in_features[1:], self.layers[1:]): cur_features = layer(torch.cat([cur_features, in_feat], dim=1)) + if self.original_size_to_interpolate is not None: + cur_features = self._resize(cur_features) return cur_features From 75a0775aa0a4e905c9ab59a84f59e4595dfe8b8a Mon Sep 17 00:00:00 2001 From: hgherzog Date: Thu, 25 Sep 2025 19:54:01 +0000 Subject: [PATCH 02/14] Add source files --- rslearn/models/copernicus_fm.py | 206 -------------------------------- 1 file changed, 206 deletions(-) delete mode 100644 rslearn/models/copernicus_fm.py diff --git a/rslearn/models/copernicus_fm.py b/rslearn/models/copernicus_fm.py deleted file mode 100644 index ac313fed..00000000 --- a/rslearn/models/copernicus_fm.py +++ /dev/null @@ -1,206 +0,0 @@ -"""Copernicus FM model.""" - -from enum import Enum -import logging -import math - -import torch -import torch.nn.functional as F -from einops import rearrange -from upath import UPath - -from .copernicusfm.src.model_vit import vit_base_patch16 - - -logger = logging.getLogger(__name__) -class CopernicusFMModality(Enum): - """Modality for Copernicus FM.""" - SENTINEL2_L2A = "sentinel2_l2a" - SENTINEL1 = "sentinel1" - -MODALITY_TO_WAVELENGTH_BANDWIDTHS: dict[str, dict[str, list]] = { - # https://github.com/zhu-xlab/Copernicus-FM/blob/main/Copernicus-Bench/src/configs/dataset/cobench_eurosat_s2.yaml - CopernicusFMModality.SENTINEL2_L2A.value: { - "band_names": [ - "B01", - "B02", - "B03", - "B04", - "B05", - "B06", - "B07", - "B08", - "B8A", - "B09", - "B10", - "B11", - "B12", - ], - "band_wavelengths": [ - 440, - 490, - 560, - 665, - 705, - 740, - 783, - 842, - 860, - 940, - 1370, - 1610, - 2190, - ], - "band_bandwidths": [20, 65, 35, 30, 15, 15, 20, 115, 20, 20, 30, 90, 180], - }, - # https://github.com/zhu-xlab/Copernicus-FM/blob/main/Copernicus-Bench/src/configs/dataset/cobench_eurosat_s1.yaml - CopernicusFMModality.SENTINEL1.value: { - "band_names": ["vv", "vh"], - "band_wavelengths": [50000000, 50000000], - "band_bandwidths": [1e9, 1e9], - }, -} - - -class CopernicusFM(torch.nn.Module): - """Wrapper for Copernicus FM to ingest Masked Helios Sample.""" - - image_resolution = 224 - patch_size = 16 - input_mode = "spectral" - supported_modalities = [CopernicusFMModality.SENTINEL2_L2A.value, CopernicusFMModality.SENTINEL1.value] - - def __init__( - self, band_order: dict[str, list[str]], load_directory: str = "/weka/dfive-default/helios/models/copernicusfm" - ) -> None: - """Initialize the Copernicus FM wrapper. - - Args: - load_directory: The directory to load from - """ - super().__init__() - - # global_pool=True so that we initialize the fc_norm layer - self.band_order = band_order - self.model = vit_base_patch16(num_classes=10, global_pool=True) - check_point = torch.load( - UPath(load_directory) / "CopernicusFM_ViT_base_varlang_e100.pth" - ) - if "model" in check_point: - state_dict = check_point["model"] - else: - state_dict = check_point - self.model.load_state_dict(state_dict, strict=False) - - # take MODALITY_TO_WAVELENGTH_BANDWIDTHS and rearrage it so that it has the same - # ordering as the Helios band orders, defined by Modality.band_order - self.modality_to_wavelength_bandwidths = {} - for modality in self.supported_modalities: - wavelength_bandwidths = MODALITY_TO_WAVELENGTH_BANDWIDTHS[modality] - band_order = MODALITY_TO_WAVELENGTH_BANDWIDTHS[modality]["band_names"] - wavelengths = [] - bandwidths = [] - band_order = self.band_order.get(modality, None) - if band_order is None: - logger.warning(f"Band order for modality {modality} not found in band_order dictionary, unable to use this modality unless specified") - continue - for b in band_order: - cfm_idx = wavelength_bandwidths["band_names"].index(b) - wavelengths.append(wavelength_bandwidths["band_wavelengths"][cfm_idx]) - bandwidths.append(wavelength_bandwidths["band_bandwidths"][cfm_idx]) - self.modality_to_wavelength_bandwidths[modality] = { - "band_bandwidths": bandwidths, - "band_wavelengths": wavelengths, - } - - def _resize_data( - self, data: torch.Tensor - ) -> list[torch.Tensor]: - """Process individual modality data. - - Args: - data: Input tensor of shape [B, C, H, W] - - Returns: - list of tensors of shape [B, C, H, W] - """ - # Get original dimensions - original_height = data.shape[2] - new_height = ( - self.patch_size if original_height == 1 else self.image_resolution - ) - data = F.interpolate( - data, - size=(new_height, new_height), - mode="bilinear", - align_corners=False, - ) - return data - - def prepare_input( - self, - inputs: dict[str, torch.Tensor], - ) -> tuple[list[torch.Tensor], list[int], list[int]]: - """Prepare input for the CopernicusFM model from MaskedHeliosSample.""" - wavelengths: list[int] = [] - bandwidths: list[int] = [] - all_processed_data: list[list[torch.Tensor]] = [] - for modality in inputs.keys(): - if modality not in self.supported_modalities: - logger.warning( - f"Skipping modality {modality} as it is not in the supported " - f"modalities list {self.supported_modalities}" - ) - continue - - data = inputs[modality] - - if data is None: - continue - - all_processed_data.append(self._resize_data(data)) - wavelengths.extend( - self.modality_to_wavelength_bandwidths[modality]["band_wavelengths"] - ) - bandwidths.extend( - self.modality_to_wavelength_bandwidths[modality]["band_bandwidths"] - ) - - concatenated_processed_data = torch.cat(all_processed_data, dim=1) - return concatenated_processed_data, wavelengths, bandwidths - - def forward( - self, - inputs: list[dict[str, torch.Tensor]], - ) -> torch.Tensor: - """Forward pass through CopernicusFM model.""" - batch_inputs = {key: torch.stack([inp[key] for inp in inputs], dim=0) for key in inputs[0].keys()} - # Prepare input - data, wavelengths, bandwidths = self.prepare_input( - batch_inputs - ) - meta = torch.full( - (1, 4), float("nan"), device=data.device - ) # [lon, lat, delta_time, patch_token_area], assume unknown - # "The embed tensor contains the encoded image features, which can be used for downstream tasks." - _, timestep_output = self.model( - data, - meta, - wavelengths, - bandwidths, - None, - self.input_mode, - self.patch_size, - ) - # no norm, following - # https://github.com/zhu-xlab/Copernicus-FM/blob/main/Copernicus-Bench/src/foundation_models/CopernicusFM/models_dwv_seg.py - side = math.isqrt(timestep_output.shape[1]) - output_features = rearrange( - timestep_output, "b (h w) c -> b c h w ", h=side, w=side - ) - return [output_features] - - def get_backbone_channels(self) -> list[tuple[int, int]]: - """Returns the output channels of this model when used as a backbone.""" - # TODO: load this from a constant depending on the model size - return [(self.patch_size, 768)] From d2b533c72b697601f2c9750e42529768b23bc1c5 Mon Sep 17 00:00:00 2001 From: hgherzog Date: Thu, 25 Sep 2025 19:59:30 +0000 Subject: [PATCH 03/14] add source files --- rslearn/models/copernicusfm/__init__.py | 1 + rslearn/models/copernicusfm/aurora/area.py | 146 +++++ rslearn/models/copernicusfm/aurora/fourier.py | 134 +++++ .../copernicusfm/dynamic_hypernetwork.py | 529 ++++++++++++++++++ .../copernicusfm/flexivit/patch_embed.py | 260 +++++++++ rslearn/models/copernicusfm/flexivit/utils.py | 69 +++ rslearn/models/copernicusfm/model_vit.py | 348 ++++++++++++ rslearn/models/copernicusfm/util/lr_sched.py | 29 + rslearn/models/copernicusfm/util/misc.py | 404 +++++++++++++ rslearn/models/copernicusfm/util/pos_embed.py | 216 +++++++ 10 files changed, 2136 insertions(+) create mode 100644 rslearn/models/copernicusfm/__init__.py create mode 100644 rslearn/models/copernicusfm/aurora/area.py create mode 100644 rslearn/models/copernicusfm/aurora/fourier.py create mode 100644 rslearn/models/copernicusfm/dynamic_hypernetwork.py create mode 100644 rslearn/models/copernicusfm/flexivit/patch_embed.py create mode 100644 rslearn/models/copernicusfm/flexivit/utils.py create mode 100644 rslearn/models/copernicusfm/model_vit.py create mode 100644 rslearn/models/copernicusfm/util/lr_sched.py create mode 100644 rslearn/models/copernicusfm/util/misc.py create mode 100644 rslearn/models/copernicusfm/util/pos_embed.py diff --git a/rslearn/models/copernicusfm/__init__.py b/rslearn/models/copernicusfm/__init__.py new file mode 100644 index 00000000..b97b4284 --- /dev/null +++ b/rslearn/models/copernicusfm/__init__.py @@ -0,0 +1 @@ +# type: ignore diff --git a/rslearn/models/copernicusfm/aurora/area.py b/rslearn/models/copernicusfm/aurora/area.py new file mode 100644 index 00000000..d980e66c --- /dev/null +++ b/rslearn/models/copernicusfm/aurora/area.py @@ -0,0 +1,146 @@ +"""Copyright (c) Microsoft Corporation. Licensed under the MIT license.""" + +import torch + +__all__ = ["area", "compute_patch_areas", "radius_earth"] + + +# float: Radius of the earth in kilometers. +radius_earth = 6378137 / 1000 + + +def area(polygon: torch.Tensor) -> torch.Tensor: + """Compute the area of a polygon specified by latitudes and longitudes in degrees. + + This function is a PyTorch port of the PyPI package `area`. In particular, it is heavily + inspired by the following file: + + https://github.com/scisco/area/blob/9d9549d6ebffcbe4bffe11b71efa2d406d1c9fe9/area/__init__.py + + Args: + polygon (:class:`torch.Tensor`): Polygon of the shape `(*b, n, 2)` where `b` is an optional + multidimensional batch size, `n` is the number of points of the polygon, and 2 + concatenates first latitudes and then longitudes. The polygon does not have be closed. + + Returns: + :class:`torch.Tensor`: Area in square kilometers. + """ + # Be sure to close the loop. + polygon = torch.cat((polygon, polygon[..., -1:, :]), axis=-2) + + area = torch.zeros(polygon.shape[:-2], dtype=polygon.dtype, device=polygon.device) + n = polygon.shape[-2] # Number of points of the polygon + + rad = torch.deg2rad # Convert degrees to radians. + + if n > 2: + for i in range(n): + i_lower = i + i_middle = (i + 1) % n + i_upper = (i + 2) % n + + lon_lower = polygon[..., i_lower, 1] + lat_middle = polygon[..., i_middle, 0] + lon_upper = polygon[..., i_upper, 1] + + area = area + (rad(lon_upper) - rad(lon_lower)) * torch.sin(rad(lat_middle)) + + area = area * radius_earth * radius_earth / 2 + + return torch.abs(area) + + +def expand_matrix(matrix: torch.Tensor) -> torch.Tensor: + """Expand matrix by adding one row and one column to each side, using + linear interpolation. + + Args: + matrix (:class:`torch.Tensor`): Matrix to expand. + + Returns: + :class:`torch.Tensor`: `matrix`, but with two extra rows and two extra columns. + """ + # Add top and bottom rows. + matrix = torch.cat( + ( + 2 * matrix[0:1] - matrix[1:2], + matrix, + 2 * matrix[-1:] - matrix[-2:-1], + ), + dim=0, + ) + + # Add left and right columns. + matrix = torch.cat( + ( + 2 * matrix[:, 0:1] - matrix[:, 1:2], + matrix, + 2 * matrix[:, -1:] - matrix[:, -2:-1], + ), + dim=1, + ) + + return matrix + + +def compute_patch_areas(lat: torch.Tensor, lon: torch.Tensor) -> torch.Tensor: + """A pair of latitude and longitude matrices defines a number non-intersecting patches on the + Earth. For a global grid, these patches span the entire surface of the Earth. For a local grid, + the patches might span only a country or a continent. This function computes the area of every + specified patch. + + To divide the Earth into patches, the idea is to let a grid point be the _center_ of the + corresponding patch. The vertices of this patch will then sit exactly inbetween the grid + point and the grid points immediately diagonally and non-diagonally above, below, left, and + right. For a grid point at the very top of the grid, for example, there is no immediately above + grid point. In that case, we enlarge the grid by a row at the top by linearly interpolating the + latitudinal progression. + + Summary of algorithm: + 1. Enlarge the latitude and longitude matrices by adding one row and one column to each side. + 2. Calculate the patch vertices by averaging every 2x2 square in the enlarged grid. We also + call these points the midpoints. + 3. By using the vertices of the patches, i.e. the midpoints, compute the areas of the patches. + + Args: + lat (:class:`torch.Tensor`): Latitude matrix. Must be decreasing along rows. + lon (:class:`torch.Tensor`): Longitude matrix. Must be increasing along columns. + + Returns: + :class:`torch.Tensor`: Areas in square kilometer. + """ + if not (lat.dim() == lon.dim() == 2): + raise ValueError("`lat` and `lon` must both be matrices.") + if lat.shape != lat.shape: + raise ValueError("`lat` and `lon` must have the same shape.") + + # Check that the latitude matrix is decreasing in the appropriate way. + if not torch.all(lat[1:] - lat[:-1] <= 0): + raise ValueError("`lat` must be decreasing along rows.") + + # Check that the longitude matrix is increasing in the appropriate way. + if not torch.all(lon[:, 1:] - lon[:, :-1] >= 0): + raise ValueError("`lon` must be increasing along columns.") + + # Enlarge the latitude and longitude matrices for the midpoint computation. + lat = expand_matrix(lat) + lon = expand_matrix(lon) + + # Latitudes cannot expand beyond the poles. + lat = torch.clamp(lat, -90, 90) + + # Calculate midpoints between entries in lat/lon. This is very important for symmetry of the + # resulting areas. + lat_midpoints = (lat[:-1, :-1] + lat[:-1, 1:] + lat[1:, :-1] + lat[1:, 1:]) / 4 + lon_midpoints = (lon[:-1, :-1] + lon[:-1, 1:] + lon[1:, :-1] + lon[1:, 1:]) / 4 + + # Determine squares and return the area of those squares. + top_left = torch.stack((lat_midpoints[1:, :-1], lon_midpoints[1:, :-1]), dim=-1) + top_right = torch.stack((lat_midpoints[1:, 1:], lon_midpoints[1:, 1:]), dim=-1) + bottom_left = torch.stack( + (lat_midpoints[:-1, :-1], lon_midpoints[:-1, :-1]), dim=-1 + ) + bottom_right = torch.stack((lat_midpoints[:-1, 1:], lon_midpoints[:-1, 1:]), dim=-1) + polygon = torch.stack((top_left, top_right, bottom_right, bottom_left), dim=-2) + + return area(polygon) diff --git a/rslearn/models/copernicusfm/aurora/fourier.py b/rslearn/models/copernicusfm/aurora/fourier.py new file mode 100644 index 00000000..4b96d7fb --- /dev/null +++ b/rslearn/models/copernicusfm/aurora/fourier.py @@ -0,0 +1,134 @@ +# type: ignore +"""Copyright (c) Microsoft Corporation. Licensed under the MIT license.""" + +import math + +import numpy as np +import torch +import torch.nn as nn + +from .area import area, radius_earth + +__all__ = [ + "FourierExpansion", + "pos_expansion", + "scale_expansion", + "lead_time_expansion", + "levels_expansion", + "absolute_time_expansion", +] + + +class FourierExpansion(nn.Module): + """A Fourier series-style expansion into a high-dimensional space. + + Attributes: + lower (float): Lower wavelength. + upper (float): Upper wavelength. + assert_range (bool): Assert that the encoded tensor is within the specified wavelength + range. + """ + + def __init__(self, lower: float, upper: float, assert_range: bool = True) -> None: + """Initialise. + + Args: + lower (float): Lower wavelength. + upper (float): Upper wavelength. + assert_range (bool, optional): Assert that the encoded tensor is within the specified + wavelength range. Defaults to `True`. + """ + super().__init__() + self.lower = lower + self.upper = upper + self.assert_range = assert_range + + def forward(self, x: torch.Tensor, d: int) -> torch.Tensor: + """Perform the expansion. + + Adds a dimension of length `d` to the end of the shape of `x`. + + Args: + x (:class:`torch.Tensor`): Input to expand of shape `(..., n)`. All elements of `x` must + lie within `[self.lower, self.upper]` if `self.assert_range` is `True`. + d (int): Dimensionality. Must be a multiple of two. + + Raises: + AssertionError: If `self.assert_range` is `True` and not all elements of `x` are not + within `[self.lower, self.upper]`. + ValueError: If `d` is not a multiple of two. + + Returns: + torch.Tensor: Fourier series-style expansion of `x` of shape `(..., n, d)`. + """ + # If the input is not within the configured range, the embedding might be ambiguous! + in_range = torch.logical_and( + self.lower <= x.abs(), torch.all(x.abs() <= self.upper) + ) + in_range_or_zero = torch.all( + torch.logical_or(in_range, x == 0) + ) # Allow zeros to pass through. + if self.assert_range and not in_range_or_zero: + raise AssertionError( + f"The input tensor is not within the configured range" + f" `[{self.lower}, {self.upper}]`." + ) + + # We will use half of the dimensionality for `sin` and the other half for `cos`. + if not (d % 2 == 0): + raise ValueError("The dimensionality must be a multiple of two.") + + # Always perform the expansion with `float64`s to avoid numerical accuracy shenanigans. + x = x.double() + + wavelengths = torch.logspace( + math.log10(self.lower), + math.log10(self.upper), + d // 2, + base=10, + device=x.device, + dtype=x.dtype, + ) + prod = torch.einsum("...i,j->...ij", x, 2 * np.pi / wavelengths) + encoding = torch.cat((torch.sin(prod), torch.cos(prod)), dim=-1) + + return encoding.float() # Cast to `float32` to avoid incompatibilities. + + +# Determine a reasonable smallest value for the scale embedding by assuming a smallest delta in +# latitudes and longitudes. +_delta = 0.01 # Reasonable smallest delta in latitude and longitude +_min_patch_area: float = area( + torch.tensor( + [ + # The smallest patches will be at the poles. Just use the north pole. + [90, 0], + [90, _delta], + [90 - _delta, _delta], + [90 - _delta, 0], + ], + dtype=torch.float64, + ) +).item() +_area_earth = 4 * np.pi * radius_earth * radius_earth + +pos_expansion = FourierExpansion(_delta, 720) + + +scale_expansion = FourierExpansion(_min_patch_area, _area_earth) + + +lead_time_expansion = FourierExpansion(1 / 60, 24 * 7 * 3) + +levels_expansion = FourierExpansion(0.01, 1e5) + +absolute_time_expansion = FourierExpansion(1, 24 * 365.25, assert_range=False) + +### new for SSL4EO-S ### +# min wavelength: ultraviolet light (100 nm) +# max wavelength: radio waves (1 m) +spectrum_central_expansion = FourierExpansion(1e-7, 1) + +# min bandwidth: 10nm +# max bandwidth: 1m +spectrum_width_expansion = FourierExpansion(1e-7, 1) diff --git a/rslearn/models/copernicusfm/dynamic_hypernetwork.py b/rslearn/models/copernicusfm/dynamic_hypernetwork.py new file mode 100644 index 00000000..f2d5b121 --- /dev/null +++ b/rslearn/models/copernicusfm/dynamic_hypernetwork.py @@ -0,0 +1,529 @@ +# type: ignore +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.init as init + +# CopernicusFM: meta encoding (follow aurora) +from .aurora.fourier import FourierExpansion + +# CopernicusFM: dynamic patch size (follow flexivit) +from .flexivit.patch_embed import pi_resize_patch_embed +from .util.pos_embed import get_1d_sincos_pos_embed_from_grid_torch + +# from torchvision.datasets.utils import download_url + + +random_seed = 1234 +torch.manual_seed(random_seed) + + +class TransformerWeightGenerator(nn.Module): + def __init__(self, input_dim, output_dim, embed_dim, num_heads=4, num_layers=1): + super(TransformerWeightGenerator, self).__init__() + encoder_layer = nn.TransformerEncoderLayer( + d_model=input_dim, + nhead=num_heads, + activation="gelu", + norm_first=False, + batch_first=False, + dropout=False, + ) + self.transformer_encoder = nn.TransformerEncoder( + encoder_layer, num_layers=num_layers, enable_nested_tensor=False + ) + + # Linear layer to map transformer output to desired weight shape + self.fc_weight = nn.Linear(input_dim, output_dim) + self.fc_bias = nn.Linear(input_dim, embed_dim) + self.wt_num = 128 + self.weight_tokens = nn.Parameter(torch.empty([self.wt_num, input_dim])) + self.bias_token = nn.Parameter(torch.empty([1, input_dim])) + + # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) + torch.nn.init.normal_(self.weight_tokens, std=0.02) + torch.nn.init.normal_(self.bias_token, std=0.02) + + def forward(self, x): + # x should have shape [seq_len, batch, input_dim] + pos_wave = x + x = torch.cat([self.weight_tokens, pos_wave], dim=0) + x = torch.cat([x, self.bias_token], dim=0) + transformer_output = self.transformer_encoder(x) + weights = self.fc_weight(transformer_output[self.wt_num : -1] + pos_wave) + bias = self.fc_bias( + transformer_output[-1] + ) # Using the last output to generate bias + return weights, bias + + +class GaussianFourierFeatureTransform(torch.nn.Module): + """An implementation of Gaussian Fourier feature mapping. + + "Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains": + https://arxiv.org/abs/2006.10739 + https://people.eecs.berkeley.edu/~bmild/fourfeat/index.html + + Given an input of size [batches, num_input_channels, width, height], + returns a tensor of size [batches, mapping_size*2, width, height]. + """ + + def __init__(self, num_input_channels, mapping_size=256, scale=10): + super().__init__() + + self._num_input_channels = num_input_channels + self._mapping_size = mapping_size + torch.manual_seed(42) + self._B = torch.randn((num_input_channels, mapping_size)) * scale + + def forward(self, x): + assert x.dim() == 4, f"Expected 4D input (got {x.dim()}D input)" + + batches, channels, width, height = x.shape + + assert channels == self._num_input_channels, ( + f"Expected input to have {self._num_input_channels} channels (got {channels} channels)" + ) + + # Make shape compatible for matmul with _B. + # From [B, C, W, H] to [(B*W*H), C]. + x = x.permute(0, 2, 3, 1).reshape(batches * width * height, channels) + + x = x @ self._B.to(x.device) + + # From [(B*W*H), C] to [B, W, H, C] + x = x.view(batches, width, height, self._mapping_size) + # From [B, W, H, C] to [B, C, W, H] + x = x.permute(0, 3, 1, 2) + + x = 2 * np.pi * x + return torch.cat([torch.sin(x), torch.cos(x)], dim=1) + + +class Basic1d(nn.Module): + def __init__(self, in_channels, out_channels, bias=True): + super().__init__() + conv = nn.Linear(in_channels, out_channels, bias) + self.conv = nn.Sequential( + conv, + ) + if not bias: + self.conv.add_module("ln", nn.LayerNorm(out_channels)) + self.conv.add_module("relu", nn.ReLU(inplace=True)) + + def forward(self, x): + out = self.conv(x) + return out + + +class FCResLayer(nn.Module): + def __init__(self, linear_size=128): + super(FCResLayer, self).__init__() + self.l_size = linear_size + self.nonlin1 = nn.ReLU(inplace=True) + self.nonlin2 = nn.ReLU(inplace=True) + # self.dropout1 = nn.Dropout() + self.w1 = nn.Linear(self.l_size, self.l_size) + self.w2 = nn.Linear(self.l_size, self.l_size) + + def forward(self, x): + y = self.w1(x) + y = self.nonlin1(y) + # y = self.dropout1(y) + y = self.w2(y) + y = self.nonlin2(y) + out = x + y + return out + + +class Dynamic_MLP_Decoder(nn.Module): + def __init__(self, wv_planes, inter_dim=128, kernel_size=16, decoder_embed=512): + super().__init__() + self.kernel_size = kernel_size + self.wv_planes = wv_planes + self.inter_dim = inter_dim + self.decoder_embed = decoder_embed + self._num_kernel = self.kernel_size * self.kernel_size * self.decoder_embed + + # self.weight_generator = nn.Sequential(Basic1d(wv_planes, self.inter_dim, bias=True), + # nn.Linear(self.inter_dim, self._num_kernel)) + self.weight_generator = TransformerWeightGenerator( + wv_planes, self._num_kernel, decoder_embed + ) + self.scaler = 0.01 + + self._init_weights() + + def _get_weights(self, waves, batch=True): + dweights = [] + dynamic_weights = None + if batch: + dynamic_weights = self.weight_generator(waves) + else: + for i in range(waves.size(0)): + dweights.append(self.weight_generator(waves[i])) + dynamic_weights = torch.stack(dweights, dim=0) + + return dynamic_weights + + def weight_init(self, m): + if isinstance(m, nn.Linear): + init.xavier_uniform_(m.weight) + m.bias.data.fill_(0.01) + + def _init_weights(self): + """Initialize the base weights and dynamic mlp weights""" + self.weight_generator.apply(self.weight_init) + + def forward(self, img_feat, waves, kernel_size=None): + inplanes = waves.size(0) + # wv_feats: 9,128 -> 9*16*16,512 + weight, bias = self._get_weights(waves) # 9,16*16*512 + # dynamic_weight = weight.view( + # inplanes * self.kernel_size * self.kernel_size, self.decoder_embed + # ) # 9*16*16,512 + + # CopernicusFM: dynamic patch size + dynamic_weight = weight.view( + inplanes, self.kernel_size, self.kernel_size, self.decoder_embed + ) + dynamic_weight = dynamic_weight.permute([3, 0, 1, 2]) + # resize the weight to match different preferred kernel sizes + if kernel_size != None and self.kernel_size != kernel_size: + dynamic_weight = pi_resize_patch_embed( + dynamic_weight, (kernel_size, kernel_size) + ) # 512, 9, p, p + else: + kernel_size = self.kernel_size + dynamic_weight = ( + dynamic_weight.permute([1, 2, 3, 0]) + .contiguous() + .view(-1, self.decoder_embed) + ) # 9*p*p,512 + + weights = dynamic_weight * self.scaler + + dynamic_out = F.linear(img_feat, weights, bias=None) + x = dynamic_out + return x + + +class Dynamic_Patch_Embed(nn.Module): + """Input: channels of wavelength (normalized): List -> List + kernel size of the depth-wise convolution: kernel_size, default 3x3 + wv_planes + inplanes + """ + + def __init__(self, wv_planes, inter_dim=128, kernel_size=3, embed_dim=1024): + super().__init__() + self.kernel_size = kernel_size + self.wv_planes = wv_planes + self.embed_dim = embed_dim + self.kernel_size = kernel_size + self.patch_size = (kernel_size, kernel_size) + self.weight2 = nn.Parameter( + torch.empty([embed_dim, 2, kernel_size, kernel_size]) + ) + self.bias2 = nn.Parameter(torch.empty([embed_dim])) + self.weight3 = nn.Parameter( + torch.empty([embed_dim, 3, kernel_size, kernel_size]) + ) + self.bias3 = nn.Parameter(torch.empty([embed_dim])) + self.weight4 = nn.Parameter( + torch.empty([embed_dim, 4, kernel_size, kernel_size]) + ) + self.bias4 = nn.Parameter(torch.empty([embed_dim])) + self.weight9 = nn.Parameter( + torch.empty([embed_dim, 9, kernel_size, kernel_size]) + ) + self.bias9 = nn.Parameter(torch.empty([embed_dim])) + self.weight70 = nn.Parameter( + torch.empty([embed_dim, 70, kernel_size, kernel_size]) + ) + self.bias70 = nn.Parameter(torch.empty([embed_dim])) + self.weights = { + 2: self.weight2, + 3: self.weight3, + 4: self.weight4, + 9: self.weight9, + 70: self.weight70, + } + self.biass = { + 2: self.bias2, + 3: self.bias3, + 4: self.bias4, + 9: self.bias9, + 70: self.bias70, + } + + def forward(self, img_feat, waves): + inplanes = waves.size(0) + # wv_feats: 9,128 -> 9, 3x3x3 + weights = self.weights[inplanes] + bias = self.biass[inplanes] + + dynamic_out = F.conv2d( + img_feat, weights, bias=bias, stride=self.kernel_size, padding=1, dilation=1 + ) + + x = dynamic_out + x = x.flatten(2).transpose(1, 2) + + return x + + +class Dynamic_MLP_OFA(nn.Module): + """Input: channels of wavelength (normalized): List -> List + kernel size of the depth-wise convolution: kernel_size, default 3x3 + wv_planes + inplanes + """ + + def __init__(self, wv_planes, inter_dim=128, kernel_size=3, embed_dim=1024): + super().__init__() + self.kernel_size = kernel_size + self.wv_planes = wv_planes + self.embed_dim = embed_dim + self.kernel_size = kernel_size + self._num_kernel = self.kernel_size * self.kernel_size * self.embed_dim + self.inter_dim = inter_dim + self.patch_size = (kernel_size, kernel_size) + self.num_patches = -1 + + self.weight_generator = TransformerWeightGenerator( + wv_planes, self._num_kernel, embed_dim + ) + self.scaler = 0.01 + + self.fclayer = FCResLayer(wv_planes) + + self._init_weights() + + def _get_weights(self, waves): + dynamic_weights = self.weight_generator(waves) + return dynamic_weights + + def weight_init(self, m): + if isinstance(m, nn.Linear): + init.xavier_uniform_(m.weight) + m.bias.data.fill_(0.01) + + def _init_weights(self): + """Initialize the base weights and dynamic mlp weights""" + self.weight_generator.apply(self.weight_init) + self.fclayer.apply(self.weight_init) + + def forward(self, img_feat, wvs): + inplanes = wvs.size(0) + # wv_feats: 9,128 -> 9, 3x3x3 + waves = get_1d_sincos_pos_embed_from_grid_torch(self.wv_planes, wvs * 1000) + waves = self.fclayer(waves) + weight, bias = self._get_weights(waves) # 3x3x3 + # bias = None + + # dynamic_weight = weight.view(self.embed_dim, inplanes, self.kernel_size, self.kernel_size) #3xoutdx16x16 + dynamic_weight = weight.view( + inplanes, self.kernel_size, self.kernel_size, self.embed_dim + ) + dynamic_weight = dynamic_weight.permute([3, 0, 1, 2]) + if bias is not None: + bias = bias.view([self.embed_dim]) * self.scaler + + weights = dynamic_weight * self.scaler + + dynamic_out = F.conv2d( + img_feat, weights, bias=bias, stride=self.kernel_size, padding=1, dilation=1 + ) + + x = dynamic_out + x = x.flatten(2).transpose(1, 2) + + return x, waves + + +class Dynamic_MLP_OFA_spectral(nn.Module): + """Input: channels of wavelength and bandwidth (normalized): List -> List + kernel size of the depth-wise convolution: kernel_size, default 3x3 + wv_planes + inplanes + """ + + def __init__(self, wv_planes, inter_dim=128, kernel_size=3, embed_dim=1024): + super().__init__() + self.kernel_size = kernel_size + self.wv_planes = wv_planes + self.embed_dim = embed_dim + self.kernel_size = kernel_size + self._num_kernel = self.kernel_size * self.kernel_size * self.embed_dim + self.inter_dim = inter_dim + self.patch_size = (kernel_size, kernel_size) + self.num_patches = -1 + + ## CopernicusFM: fourier embedding for wavelength and bandwidth + # min wavelength: ultraviolet light (100 nm) + # max wavelength: radio waves (1 m) + self.spectrum_central_expansion = FourierExpansion(100, 1e9) + # min bandwidth: s2 ~ 10nm + # max bandwidth: s1 ~ 1m + self.spectrum_bandwidth_expansion = FourierExpansion(1, 1e9) + + self.weight_generator = TransformerWeightGenerator( + wv_planes, self._num_kernel, embed_dim + ) + self.scaler = 0.01 + + self.fclayer = FCResLayer(wv_planes) + + self._init_weights() + + def _get_weights(self, waves): + dynamic_weights = self.weight_generator(waves) + + return dynamic_weights + + def weight_init(self, m): + if isinstance(m, nn.Linear): + init.xavier_uniform_(m.weight) + m.bias.data.fill_(0.01) + + def _init_weights(self): + """Initialize the base weights and dynamic mlp weights""" + self.weight_generator.apply(self.weight_init) + self.fclayer.apply(self.weight_init) + + def forward(self, img_feat, wvs, bandwidths, kernel_size=None): + """wvs: nm + bandwidths: nm + """ + inplanes = wvs.size(0) + # wv_feats: 9,128 -> 9, 3x3x3 + # waves = get_1d_sincos_pos_embed_from_grid_torch(self.wv_planes, wvs * 1000) # dofa: fixed sincos pos embedding + # waves = get_1d_fourier_pos_embed_from_grid_torch(self.wv_planes, wvs * 1000) # new: fourier pos embedding + emb_central = self.spectrum_central_expansion(wvs, self.wv_planes) + emb_bandwidth = self.spectrum_bandwidth_expansion(bandwidths, self.wv_planes) + waves = ( + emb_central + emb_bandwidth + ) # simply add two embeddings, can be more complex later + + waves = self.fclayer(waves) + weight, bias = self._get_weights(waves) # 3x3x3 + + # Fix bug + dynamic_weight = weight.view( + inplanes, self.kernel_size, self.kernel_size, self.embed_dim + ) # 9, 3, 3, 1024 + dynamic_weight = dynamic_weight.permute([3, 0, 1, 2]) # 1024, 9, 3, 3 + # resize the weight to match different preferred kernel sizes + if kernel_size != None and self.kernel_size != kernel_size: + dynamic_weight = pi_resize_patch_embed( + dynamic_weight, (kernel_size, kernel_size) + ) + else: + kernel_size = self.kernel_size + + if bias is not None: + bias = bias.view([self.embed_dim]) * self.scaler + + weights = dynamic_weight * self.scaler + + dynamic_out = F.conv2d( + img_feat, weights, bias=bias, stride=kernel_size, padding=1, dilation=1 + ) + + x = dynamic_out + x = x.flatten(2).transpose(1, 2) + + return x, waves + + +class Dynamic_MLP_OFA_variable(nn.Module): + """Input: language embedding of variable name: Pytorch tensor + kernel size of the depth-wise convolution: kernel_size, default 3x3 + wv_planes + inplanes + """ + + def __init__(self, wv_planes, inter_dim=128, kernel_size=3, embed_dim=1024): + super().__init__() + self.kernel_size = kernel_size + self.wv_planes = wv_planes + self.embed_dim = embed_dim + self.kernel_size = kernel_size + self._num_kernel = self.kernel_size * self.kernel_size * self.embed_dim + self.inter_dim = inter_dim + self.patch_size = (kernel_size, kernel_size) + self.num_patches = -1 + + self.language_proj = nn.Linear( + 2048, self.wv_planes + ) # project to the same dimension as wv_planes + + self.weight_generator = TransformerWeightGenerator( + wv_planes, self._num_kernel, embed_dim + ) + self.scaler = 0.01 + + self.fclayer = FCResLayer(wv_planes) + + self._init_weights() + + def _get_weights(self, waves): + dynamic_weights = self.weight_generator(waves) + + return dynamic_weights + + def weight_init(self, m): + if isinstance(m, nn.Linear): + init.xavier_uniform_(m.weight) + m.bias.data.fill_(0.01) + + def _init_weights(self): + """Initialize the base weights and dynamic mlp weights""" + self.weight_generator.apply(self.weight_init) + self.fclayer.apply(self.weight_init) + + def forward(self, img_feat, language_embed, kernel_size=None): + """wvs: nm + bandwidths: nm + """ + # wv_feats: 9,128 -> 9, 3x3x3 + emb_language = language_embed.unsqueeze(0) + waves = self.language_proj(emb_language) + # print(waves.size()) + + waves = self.fclayer(waves) + # print(waves.size()) + weight, bias = self._get_weights(waves) # 3x3x3 + + # inplanes = wvs.size(0) + inplanes = waves.size(0) + # print(inplanes) + # Fix bug + dynamic_weight = weight.view( + inplanes, self.kernel_size, self.kernel_size, self.embed_dim + ) # 9, 3, 3, 1024 + dynamic_weight = dynamic_weight.permute([3, 0, 1, 2]) # 1024, 9, 3, 3 + + # resize the weight to match different preferred kernel sizes + if kernel_size != None and self.kernel_size != kernel_size: + dynamic_weight = pi_resize_patch_embed( + dynamic_weight, (kernel_size, kernel_size) + ) + else: + kernel_size = self.kernel_size + + if bias is not None: + bias = bias.view([self.embed_dim]) * self.scaler + + weights = dynamic_weight * self.scaler + + dynamic_out = F.conv2d( + img_feat, weights, bias=bias, stride=kernel_size, padding=1, dilation=1 + ) + + x = dynamic_out + x = x.flatten(2).transpose(1, 2) + + return x, waves diff --git a/rslearn/models/copernicusfm/flexivit/patch_embed.py b/rslearn/models/copernicusfm/flexivit/patch_embed.py new file mode 100644 index 00000000..8a53f9c1 --- /dev/null +++ b/rslearn/models/copernicusfm/flexivit/patch_embed.py @@ -0,0 +1,260 @@ +# type: ignore +from collections.abc import Sequence + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from functorch import vmap +from torch import Tensor + +from .utils import to_2tuple + + +def pi_resize_patch_embed( + patch_embed: Tensor, + new_patch_size: tuple[int, int], + interpolation: str = "bicubic", + antialias: bool = True, +): + """Resample patch embedding weights to a target resolution via pseudo-inverse + resizing. + + Based on: + https://github.com/google-research/big_vision/blob/b00544b81f8694488d5f36295aeb7972f3755ffe/big_vision/models/proj/flexi/vit.py + https://arxiv.org/abs/2212.08013 + + Args: + patch_embed: Patch embedding parameters of size [d, c, h, w] + new_patch_size: Target [height, width] of embedding + interpolation: Resize interpolation type + antialias: Whether to apply antialiasing resizing + Returns: + Resized pos_embed of size [d, c h', w'] + """ + assert len(patch_embed.shape) == 4, "Patch embed kernel should be a 4D tensor" + assert len(new_patch_size) == 2, "New patch size should only be (height, width)" + + old_patch_size = tuple(patch_embed.shape[2:]) + + # Return original kernel if no resize is necessary + if old_patch_size == new_patch_size: + return patch_embed + + def resize(x: Tensor, shape: tuple[int, int]): + x_resized = F.interpolate( + x[None, None, ...], + shape, + mode=interpolation, + antialias=antialias, + ) + return x_resized[0, 0, ...] + + def calculate_pinv(old_shape: tuple[int, int], new_shape: tuple[int, int]): + mat = [] + for i in range(np.prod(old_shape)): + basis_vec = torch.zeros(old_shape) + basis_vec[np.unravel_index(i, old_shape)] = 1.0 + mat.append(resize(basis_vec, new_shape).reshape(-1)) + resize_matrix = torch.stack(mat) + return torch.linalg.pinv(resize_matrix) + + # Calculate pseudo-inverse of resize matrix + resize_matrix_pinv = calculate_pinv(old_patch_size, new_patch_size) + resize_matrix_pinv = resize_matrix_pinv.to(patch_embed.device) + + def resample_patch_embed(patch_embed: Tensor): + h, w = new_patch_size + resampled_kernel = resize_matrix_pinv @ patch_embed.reshape(-1) + return rearrange(resampled_kernel, "(h w) -> h w", h=h, w=w) + + v_resample_patch_embed = vmap(vmap(resample_patch_embed, 0, 0), 1, 1) + + return v_resample_patch_embed(patch_embed) + + +def interpolate_resize_patch_embed( + patch_embed: Tensor, + new_patch_size: tuple[int, int], + interpolation: str = "bicubic", + antialias: bool = True, +): + """Resample patch embedding weights to a target resolution via interpolation + + Args: + patch_embed: Patch embedding parameters of size [d, c, h, w] + new_patch_size: Target [height, width] of embedding + interpolation: Resize interpolation type + antialias: Whether to apply antialiasing resizing + Returns: + Resized pos_embed of size [d, c h', w'] + """ + assert len(patch_embed.shape) == 4, "Patch embed kernel should be a 4D tensor" + assert len(new_patch_size) == 2, "New patch size should only be (height, width)" + + patch_embed = F.interpolate( + patch_embed, new_patch_size, mode=interpolation, antialias=antialias + ) + + return patch_embed + + +class FlexiPatchEmbed(nn.Module): + def __init__( + self, + img_size: int | tuple[int, int] = 240, + patch_size: int | tuple[int, int] = 32, + grid_size: int | tuple[int, int] = 7, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: nn.Module | None = None, + flatten: bool = True, + bias: bool = True, + patch_size_seq: Sequence[int] = (8, 10, 12, 15, 16, 20, 24, 30, 40, 48), + patch_size_probs: Sequence[float] | None = None, + interpolation: str = "bicubic", + antialias: bool = True, + ) -> None: + """2D image to patch embedding w/ flexible patch sizes + Extended from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/patch_embed.py#L24 + + Args: + img_size: Input image size + patch_size: Base patch size. i.e the size of the parameter buffer + grid_size: Size of pos_embed buffer + in_chans: Number of input image channels + embed_dim: Network embedding dimension size + norm_layer: Optional normalization layer + flatten: Whether to flatten the spatial dimensions of the output + bias: Whether to use bias in convolution + patch_size_seq: List of patch sizes to randomly sample from + patch_size_probs: Optional list of probabilities to sample corresponding + patch_size_seq elements. If None, then uniform distribution is used + interpolation: Resize interpolation type + antialias: Whether to apply antialiasing resizing + """ + super().__init__() + + self.img_size = to_2tuple(img_size) + self.patch_size = to_2tuple(patch_size) + self.grid_size = to_2tuple(grid_size) + self.num_patches = self.grid_size[0] * self.grid_size[1] + + self.flatten = flatten + self.proj = nn.Conv2d( + in_chans, + embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=bias, + ) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + # Flexi specific attributes + self.interpolation = interpolation + self.antialias = antialias + + self.patch_size_seq = patch_size_seq + + if self.patch_size_seq: + if not patch_size_probs: + n = len(self.patch_size_seq) + self.patch_size_probs = [1.0 / n] * n + else: + self.patch_size_probs = [ + p / sum(patch_size_probs) for p in patch_size_probs + ] + else: + self.patch_size_probs = [] + + # Pre-calculate pinvs + self.pinvs = self._cache_pinvs() + + def _cache_pinvs(self) -> dict: + """Pre-calculate all pinv matrices""" + pinvs = {} + for ps in self.patch_size_seq: + ps = to_2tuple(ps) + pinvs[ps] = self._calculate_pinv(self.patch_size, ps) + return pinvs + + def _resize(self, x: Tensor, shape: tuple[int, int]) -> Tensor: + x_resized = F.interpolate( + x[None, None, ...], + shape, + mode=self.interpolation, + antialias=self.antialias, + ) + return x_resized[0, 0, ...] + + def _calculate_pinv( + self, old_shape: tuple[int, int], new_shape: tuple[int, int] + ) -> Tensor: + mat = [] + for i in range(np.prod(old_shape)): + basis_vec = torch.zeros(old_shape) + basis_vec[np.unravel_index(i, old_shape)] = 1.0 + mat.append(self._resize(basis_vec, new_shape).reshape(-1)) + resize_matrix = torch.stack(mat) + return torch.linalg.pinv(resize_matrix) + + def resize_patch_embed(self, patch_embed: Tensor, new_patch_size: tuple[int, int]): + """Resize patch_embed to target resolution via pseudo-inverse resizing""" + # Return original kernel if no resize is necessary + if self.patch_size == new_patch_size: + return patch_embed + + # Calculate pseudo-inverse of resize matrix + if new_patch_size not in self.pinvs: + self.pinvs[new_patch_size] = self._calculate_pinv( + self.patch_size, new_patch_size + ) + pinv = self.pinvs[new_patch_size] + pinv = pinv.to(patch_embed.device) + + def resample_patch_embed(patch_embed: Tensor): + h, w = new_patch_size + resampled_kernel = pinv @ patch_embed.reshape(-1) + return rearrange(resampled_kernel, "(h w) -> h w", h=h, w=w) + + v_resample_patch_embed = vmap(vmap(resample_patch_embed, 0, 0), 1, 1) + + return v_resample_patch_embed(patch_embed) + + def forward( + self, + x: Tensor, + patch_size: int | tuple[int, int] | None = None, + return_patch_size: bool = False, + ) -> Tensor | tuple[Tensor, tuple[int, int]]: + if not patch_size and not self.training: + # During evaluation use base patch size if not specified + patch_size = self.patch_size + elif not patch_size: + # During training choose uniformly at random if not specified + assert self.patch_size_seq, ( + "No patch size specified during forward and no patch_size_seq given to FlexiPatchEmbed" + ) + patch_size = np.random.choice(self.patch_size_seq, p=self.patch_size_probs) + + patch_size = to_2tuple(patch_size) + + # Resize conv weights + if patch_size == self.patch_size: + weight = self.proj.weight + else: + weight = self.resize_patch_embed(self.proj.weight, patch_size) + + # Apply conv with resized weights + x = F.conv2d(x, weight, bias=self.proj.bias, stride=patch_size) + + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + + x = self.norm(x) + + if return_patch_size: + return x, patch_size + + return x diff --git a/rslearn/models/copernicusfm/flexivit/utils.py b/rslearn/models/copernicusfm/flexivit/utils.py new file mode 100644 index 00000000..7e3ed0b5 --- /dev/null +++ b/rslearn/models/copernicusfm/flexivit/utils.py @@ -0,0 +1,69 @@ +# type: ignore +import collections.abc +import math +from itertools import repeat +from typing import Any + +import torch +import torch.nn.functional as F + + +def to_2tuple(x: Any) -> tuple: + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return tuple(x) + return tuple(repeat(x, 2)) + + +def resize_abs_pos_embed( + pos_embed: torch.Tensor, + new_size: tuple[int, int], + old_size: int | tuple[int, int] | None = None, + num_prefix_tokens: int = 1, + interpolation: str = "bicubic", + antialias: bool = True, +) -> torch.Tensor: + """Resize absolute position embeddings to a target resolution via interpolation + + Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/pos_embed.py + + Args: + pos_embed: Position embeddings tensor of size [b, n, d] + new_size: Target [height, width] of embedding + old_size: Original [height, width] of embedding + num_prefix_tokens: Number of non-spatial prefix tokens (eg. cls) + interpolation: Resize interpolation type + antialias: Whether to apply antialiasing resizing + Returns: + Resized pos_embed of size [b, n', d] + """ + new_size = to_2tuple(new_size) + new_ntok = new_size[0] * new_size[1] + + if not old_size: + old_size = int(math.sqrt(pos_embed.shape[1] - num_prefix_tokens)) # type:ignore + old_size = to_2tuple(old_size) + + # Return if no resize necessary + if new_size == old_size: + return pos_embed + + if num_prefix_tokens: + posemb_prefix, pos_embed = ( + pos_embed[:, :num_prefix_tokens], + pos_embed[:, num_prefix_tokens:], + ) + else: + posemb_prefix, pos_embed = None, pos_embed + + # Interpolate position embedding + pos_embed = pos_embed.reshape(1, old_size[0], old_size[1], -1).permute(0, 3, 1, 2) + pos_embed = F.interpolate( + pos_embed, size=new_size, mode=interpolation, antialias=antialias + ) + pos_embed = pos_embed.permute(0, 2, 3, 1).reshape(1, new_ntok, -1) + + # Add back extra prefix tokens + if posemb_prefix is not None: + pos_embed = torch.cat([posemb_prefix, pos_embed], dim=1) + + return pos_embed diff --git a/rslearn/models/copernicusfm/model_vit.py b/rslearn/models/copernicusfm/model_vit.py new file mode 100644 index 00000000..09d519c1 --- /dev/null +++ b/rslearn/models/copernicusfm/model_vit.py @@ -0,0 +1,348 @@ +# type: ignore +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm +# DeiT: https://github.com/facebookresearch/deit +# -------------------------------------------------------- + +import math +from functools import partial + +import torch +import torch.nn as nn +from timm.models.vision_transformer import Block + +from .aurora.fourier import FourierExpansion +from .dynamic_hypernetwork import Dynamic_MLP_OFA_spectral, Dynamic_MLP_OFA_variable +from .flexivit.utils import resize_abs_pos_embed + + +class CopernicusFMViT(nn.Module): + """CopernicusFM: VisionTransformer backbone""" + + def __init__( + self, + img_size=224, + patch_size=16, + drop_rate=0.0, + embed_dim=1024, + depth=24, + num_heads=16, + wv_planes=128, + num_classes=0, + global_pool=True, + mlp_ratio=4.0, + norm_layer=nn.LayerNorm, + loc_option="lonlat", + return_intermediate=False, + intermediate_indices=None, + ): + super().__init__() + + self.wv_planes = wv_planes + self.global_pool = global_pool + if self.global_pool: + norm_layer = norm_layer + embed_dim = embed_dim + self.fc_norm = norm_layer(embed_dim) + else: + self.norm = norm_layer(embed_dim) + + self.patch_embed_spectral = Dynamic_MLP_OFA_spectral( + wv_planes=128, inter_dim=128, kernel_size=16, embed_dim=embed_dim + ) + self.patch_embed_variable = Dynamic_MLP_OFA_variable( + wv_planes=128, inter_dim=128, kernel_size=16, embed_dim=embed_dim + ) + + self.num_patches = (img_size // patch_size) ** 2 + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + # --------------------------------------------------------------------------- + + self.pos_embed = nn.Parameter( + torch.zeros(1, self.num_patches + 1, embed_dim), requires_grad=False + ) # fixed sin-cos embedding + + self.loc_option = loc_option + if loc_option == "cartesian": + self.coord_expansion = FourierExpansion(1e-7, 2) + elif loc_option == "lonlat": + self.coord_expansion = FourierExpansion(0.0001, 720) + + self.scale_expansion = FourierExpansion(0.001, 5.1e8) # 1m2 to 5.1e8 km2 + self.time_expansion = FourierExpansion( + 1, 365.25, assert_range=False + ) # 1 to 365.25 days, enable more than 1 year + self.coord_fc = nn.Linear(embed_dim, embed_dim) + self.scale_fc = nn.Linear(embed_dim, embed_dim) + self.time_fc = nn.Linear(embed_dim, embed_dim) + # if meta info is not available, set to a learned parameter + self.coord_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.scale_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.time_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + + self.blocks = nn.ModuleList( + [ + Block( + embed_dim, + num_heads, + mlp_ratio, + qkv_bias=True, + norm_layer=norm_layer, + ) + for i in range(depth) + ] + ) + + self.head_drop = nn.Dropout(drop_rate) + self.head = ( + nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() + ) + + self.return_intermediate = return_intermediate + self.intermediate_indices = intermediate_indices + + def get_coord_pos_embed(self, lons, lats, embed_dim): + if self.loc_option == "cartesian": + # convert to spherical coordinates + spherical_x = ( + torch.cos(lons * math.pi / 180) * torch.cos(lats * math.pi / 180) + + 1 + + 1e-7 + ) + spherical_y = ( + torch.sin(lons * math.pi / 180) * torch.cos(lats * math.pi / 180) + + 1 + + 1e-7 + ) + spherical_z = torch.sin(lats * math.pi / 180) + 1 + 1e-7 + coord_embed_spherical_x = self.coord_expansion(spherical_x, embed_dim // 3) + coord_embed_spherical_y = self.coord_expansion(spherical_y, embed_dim // 3) + coord_embed_spherical_z = self.coord_expansion(spherical_z, embed_dim // 3) + coord_embed = torch.cat( + [ + coord_embed_spherical_x, + coord_embed_spherical_y, + coord_embed_spherical_z, + ], + dim=-1, + ) # [B,D] + elif self.loc_option == "lonlat": + coord_embed_lon = self.coord_expansion(lons + 180, embed_dim // 2) + coord_embed_lat = self.coord_expansion(lats + 90, embed_dim // 2) + coord_embed = torch.cat([coord_embed_lon, coord_embed_lat], dim=-1) + + if coord_embed.shape[-1] < embed_dim: + # pad zeros + coord_embed = torch.cat( + ( + coord_embed, + torch.zeros( + coord_embed.shape[0], + embed_dim - coord_embed.shape[-1], + device=coord_embed.device, + ), + ), + dim=-1, + ) + + return coord_embed.unsqueeze(1) # [B,1,D] + + def get_area_pos_embed(self, areas, embed_dim): + scale_embed = self.scale_expansion(areas, embed_dim) # B, D + return scale_embed.unsqueeze(1) # [B,1,D] + + def get_time_pos_embed(self, times, embed_dim): + time_embed = self.time_expansion(times, embed_dim) # B, D + return time_embed.unsqueeze(1) # [B,1,D] + + def forward_features( + self, + x, + meta_info, + wave_list, + bandwidth, + language_embed, + input_mode, + kernel_size=None, + ): + # embed patches + if input_mode == "spectral": + wavelist = torch.tensor(wave_list, device=x.device).float() + bandwidths = torch.tensor(bandwidth, device=x.device).float() + self.waves = wavelist + x, _ = self.patch_embed_spectral(x, self.waves, bandwidths, kernel_size) + elif input_mode == "variable": + x, _ = self.patch_embed_variable(x, language_embed, kernel_size) + + # resize pos embed + num_patches = x.size(1) + num_patches_sqrt = int(math.sqrt(num_patches)) + num_patches_sqrt_origin = int(math.sqrt(self.num_patches)) + pos_embed = resize_abs_pos_embed( + self.pos_embed, + num_patches_sqrt, + (num_patches_sqrt_origin, num_patches_sqrt_origin), + num_prefix_tokens=1, + ) + + # coord, scale and time pos embed + lons, lats, times, areas = ( + meta_info[:, 0], + meta_info[:, 1], + meta_info[:, 2], + meta_info[:, 3], + ) + embed_dim = pos_embed.shape[-1] + if torch.isnan(lons).any() or torch.isnan(lats).any(): + coord_embed = self.coord_token + else: + coord_embed = self.get_coord_pos_embed(lons, lats, embed_dim) + coord_embed = self.coord_fc(coord_embed) + if torch.isnan(areas).any(): + area_embed = self.scale_token + else: + area_embed = self.get_area_pos_embed(areas, embed_dim) + area_embed = self.scale_fc(area_embed) + if torch.isnan(times).any(): + time_embed = self.time_token + else: + time_embed = self.get_time_pos_embed(times, embed_dim) + time_embed = self.time_fc(time_embed) + pos_embed = pos_embed + coord_embed + area_embed + time_embed + + # add pos embed w/o cls token + x = x + pos_embed[:, 1:, :] + + # append cls token + cls_token = self.cls_token + pos_embed[:, :1, :] + cls_tokens = cls_token.expand(x.shape[0], -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + intermediate_features = [] + hw = num_patches_sqrt + hw_shape = (hw, hw) + + # apply Transformer blocks + for i, block in enumerate(self.blocks): + x = block(x) + if self.return_intermediate and (i in self.intermediate_indices): + out = x[:, 1:] + B, _, C = out.shape + out = ( + out.reshape(B, hw_shape[0], hw_shape[1], C) + .permute(0, 3, 1, 2) + .contiguous() + ) + intermediate_features.append(out) + + # if self.global_pool: + # x = x[:, 1:, :].mean(dim=1) # global pool without cls token + # outcome = self.fc_norm(x) + # else: + # x = self.norm(x) + # outcome = x[:, 0] + + # if self.return_intermediate: + # return outcome, intermediate_features + + # for segmentation tasks, ignore the norm + # https://github.com/zhu-xlab/Copernicus-FM/blob/main/Copernicus-Bench/src/foundation_models/CopernicusFM/models_dwv_seg.py + # for classification, we will apply the fc_norm in the wrapper + return x[:, 1:, :] + + def forward_head(self, x, pre_logits=False): + x = self.head_drop(x) + return x if pre_logits else self.head(x) + + def forward( + self, + x, + meta_info, + wave_list, + bandwidth, + language_embed, + input_mode, + kernel_size=None, + ): + if self.return_intermediate: + x, intermediate_features = self.forward_features( + x, + meta_info, + wave_list, + bandwidth, + language_embed, + input_mode, + kernel_size, + ) + return x, intermediate_features + else: + fx = self.forward_features( + x, + meta_info, + wave_list, + bandwidth, + language_embed, + input_mode, + kernel_size, + ) + x = self.forward_head(fx) + return x, fx + + +def vit_small_patch16(**kwargs): + model = CopernicusFMViT( + patch_size=16, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs, + ) + return model + + +def vit_base_patch16(**kwargs): + model = CopernicusFMViT( + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs, + ) + return model + + +def vit_large_patch16(**kwargs): + model = CopernicusFMViT( + patch_size=16, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs, + ) + return model + + +def vit_huge_patch14(**kwargs): + model = CopernicusFMViT( + patch_size=14, + embed_dim=1280, + depth=32, + num_heads=16, + mlp_ratio=4, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs, + ) + return model diff --git a/rslearn/models/copernicusfm/util/lr_sched.py b/rslearn/models/copernicusfm/util/lr_sched.py new file mode 100644 index 00000000..0044c160 --- /dev/null +++ b/rslearn/models/copernicusfm/util/lr_sched.py @@ -0,0 +1,29 @@ +# type: ignore +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math + + +def adjust_learning_rate(optimizer, epoch, args): + """Decay the learning rate with half-cycle cosine after warmup""" + if epoch < args.warmup_epochs: + lr = args.lr * epoch / args.warmup_epochs + else: + lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * ( + 1.0 + + math.cos( + math.pi + * (epoch - args.warmup_epochs) + / (args.epochs - args.warmup_epochs) + ) + ) + for param_group in optimizer.param_groups: + if "lr_scale" in param_group: + param_group["lr"] = lr * param_group["lr_scale"] + else: + param_group["lr"] = lr + return lr diff --git a/rslearn/models/copernicusfm/util/misc.py b/rslearn/models/copernicusfm/util/misc.py new file mode 100644 index 00000000..a2f59abb --- /dev/null +++ b/rslearn/models/copernicusfm/util/misc.py @@ -0,0 +1,404 @@ +# type: ignore + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# DeiT: https://github.com/facebookresearch/deit +# BEiT: https://github.com/microsoft/unilm/tree/master/beit +# -------------------------------------------------------- + +import builtins +import datetime +import os +import time +from collections import defaultdict, deque +from pathlib import Path + +import torch +import torch.distributed as dist + +# from torch._six import inf +from torch import inf + + +class SmoothedValue: + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """Warning: does not synchronize the deque!""" + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value, + ) + + +class MetricLogger: + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if v is None: + continue + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{attr}'" + ) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append(f"{name}: {str(meter)}") + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = "" + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt="{avg:.4f}") + data_time = SmoothedValue(fmt="{avg:.4f}") + space_fmt = ":" + str(len(str(len(iterable)))) + "d" + log_msg = [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + "{meters}", + "time: {time}", + "data: {data}", + ] + if torch.cuda.is_available(): + log_msg.append("max mem: {memory:.0f}") + log_msg = self.delimiter.join(log_msg) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB, + ) + ) + else: + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + ) + ) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print( + f"{header} Total time: {total_time_str} ({total_time / len(iterable):.4f} s / it)" + ) + + +def setup_for_distributed(is_master): + """This function disables printing when not in master process""" + builtin_print = builtins.print + + def print(*args, **kwargs): + force = kwargs.pop("force", False) + force = force or (get_world_size() > 16) + if is_master or force: + now = datetime.datetime.now().time() + builtin_print(f"[{now}] ", end="") # print with time stamp + builtin_print(*args, **kwargs) + + builtins.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(args): + if args.dist_on_itp: + args.rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) + args.world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"]) + args.gpu = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) + args.dist_url = "tcp://%s:%s" % ( + os.environ["MASTER_ADDR"], + os.environ["MASTER_PORT"], + ) + os.environ["LOCAL_RANK"] = str(args.gpu) + os.environ["RANK"] = str(args.rank) + os.environ["WORLD_SIZE"] = str(args.world_size) + # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] + elif "RANK" in os.environ and "WORLD_SIZE" in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ["WORLD_SIZE"]) + args.gpu = int(os.environ["LOCAL_RANK"]) + elif "SLURM_PROCID" in os.environ: + args.rank = int(os.environ["SLURM_PROCID"]) + if "SLURM_LOCALID" in os.environ: + local_rank = int(os.environ["SLURM_LOCALID"]) + # If only one GPU is visible, force to 0 + if torch.cuda.device_count() == 1: + args.gpu = 0 + else: + args.gpu = local_rank + else: + args.gpu = args.rank % torch.cuda.device_count() + args.world_size = int(os.environ["SLURM_NNODES"]) * int( + os.environ["SLURM_TASKS_PER_NODE"][0] + ) + else: + print("Not using distributed mode") + setup_for_distributed(is_master=True) # hack + args.distributed = False + return + + args.distributed = True + + print("procid", int(os.environ.get("SLURM_PROCID"))) + print("localid", int(os.environ.get("SLURM_LOCALID"))) + print("devicecount", torch.cuda.device_count()) + print("CUDA_VISIBLE_DEVICES:", os.environ.get("CUDA_VISIBLE_DEVICES")) + + torch.cuda.set_device(args.gpu) + args.dist_backend = "nccl" + print( + f"World Size {args.world_size} | distributed init (rank {args.rank}): {args.dist_url}, gpu {args.gpu}", + flush=True, + ) + torch.distributed.init_process_group( + backend=args.dist_backend, + init_method=args.dist_url, + world_size=args.world_size, + rank=args.rank, + ) + print("distributed init done") + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + print("setup print done") + + +class NativeScalerWithGradNormCount: + state_dict_key = "amp_scaler" + + def __init__(self): + self._scaler = torch.cuda.amp.GradScaler() + + def __call__( + self, + loss, + optimizer, + clip_grad=None, + parameters=None, + create_graph=False, + update_grad=True, + ): + self._scaler.scale(loss).backward(create_graph=create_graph) + if update_grad: + if clip_grad is not None: + assert parameters is not None + self._scaler.unscale_( + optimizer + ) # unscale the gradients of optimizer's assigned params in-place + norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) + else: + self._scaler.unscale_(optimizer) + norm = get_grad_norm_(parameters) + self._scaler.step(optimizer) + self._scaler.update() + else: + norm = None + return norm + + def state_dict(self): + return self._scaler.state_dict() + + def load_state_dict(self, state_dict): + self._scaler.load_state_dict(state_dict) + + +def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = [p for p in parameters if p.grad is not None] + norm_type = float(norm_type) + if len(parameters) == 0: + return torch.tensor(0.0) + device = parameters[0].grad.device + if norm_type == inf: + total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) + else: + total_norm = torch.norm( + torch.stack( + [torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters] + ), + norm_type, + ) + return total_norm + + +def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler): + output_dir = Path(args.output_dir) + epoch_name = str(epoch) + if loss_scaler is not None: + checkpoint_paths = [output_dir / ("checkpoint-%s.pth" % epoch_name)] + for checkpoint_path in checkpoint_paths: + to_save = { + "model": model_without_ddp.state_dict(), + "optimizer": optimizer.state_dict(), + "epoch": epoch, + "scaler": loss_scaler.state_dict(), + "args": args, + } + + save_on_master(to_save, checkpoint_path) + else: + client_state = {"epoch": epoch} + model.save_checkpoint( + save_dir=args.output_dir, + tag="checkpoint-%s" % epoch_name, + client_state=client_state, + ) + + +def load_model(args, model_without_ddp, optimizer, loss_scaler): + if args.resume: + if args.resume.startswith("https"): + checkpoint = torch.hub.load_state_dict_from_url( + args.resume, map_location="cpu", check_hash=True + ) + else: + checkpoint = torch.load(args.resume, map_location="cpu") + model_without_ddp.load_state_dict(checkpoint["model"]) + print("Resume checkpoint %s" % args.resume) + if ( + "optimizer" in checkpoint + and "epoch" in checkpoint + and not (hasattr(args, "eval") and args.eval) + ): + optimizer.load_state_dict(checkpoint["optimizer"]) + args.start_epoch = checkpoint["epoch"] + 1 + if "scaler" in checkpoint: + loss_scaler.load_state_dict(checkpoint["scaler"]) + print("With optim & sched!") + + +def all_reduce_mean(x): + world_size = get_world_size() + if world_size > 1: + x_reduce = torch.tensor(x).cuda() + dist.all_reduce(x_reduce) + x_reduce /= world_size + return x_reduce.item() + else: + return x diff --git a/rslearn/models/copernicusfm/util/pos_embed.py b/rslearn/models/copernicusfm/util/pos_embed.py new file mode 100644 index 00000000..7a5035f2 --- /dev/null +++ b/rslearn/models/copernicusfm/util/pos_embed.py @@ -0,0 +1,216 @@ +# type: ignore +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# Position embedding utils +# -------------------------------------------------------- + +import numpy as np +import torch + + +# -------------------------------------------------------- +# 2D sine-cosine position embedding +# References: +# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py +# MoCo v3: https://github.com/facebookresearch/moco-v3 +# -------------------------------------------------------- +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): + """grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + # omega = np.arange(embed_dim // 2, dtype=np.float) # numpy deprecated in 1.20 + omega = np.arange(embed_dim // 2, dtype=float) + + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +# -------------------------------------------------------- +# Interpolate position embeddings for high-resolution +# References: +# DeiT: https://github.com/facebookresearch/deit +# -------------------------------------------------------- +def interpolate_pos_embed(model, checkpoint_model): + if "pos_embed" in checkpoint_model: + pos_embed_checkpoint = checkpoint_model["pos_embed"] + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.patch_embed.num_patches + num_extra_tokens = model.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches**0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print( + "Position interpolate from %dx%d to %dx%d" + % (orig_size, orig_size, new_size, new_size) + ) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape( + -1, orig_size, orig_size, embedding_size + ).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, + size=(new_size, new_size), + mode="bicubic", + align_corners=False, + ) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + checkpoint_model["pos_embed"] = new_pos_embed + + +def interpolate_pos_embed_ofa(model, checkpoint_model): + if "pos_embed" in checkpoint_model: + pos_embed_dict = checkpoint_model["pos_embed"] + + for key, pos_embed in pos_embed_dict.items(): + pos_embed_checkpoint = pos_embed + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.patch_embed[key].num_patches + num_extra_tokens = model.pos_embed[key].shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches**0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print( + "Position interpolate from %dx%d to %dx%d" + % (orig_size, orig_size, new_size, new_size) + ) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape( + -1, orig_size, orig_size, embedding_size + ).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, + size=(new_size, new_size), + mode="bicubic", + align_corners=False, + ) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + checkpoint_model["pos_embed"][key] = new_pos_embed + + +def get_2d_sincos_pos_embed_with_resolution( + embed_dim, grid_size, res, cls_token=False, device="cpu" +): + """grid_size: int of the grid height and width + res: array of size n, representing the resolution of a pixel (say, in meters), + + Return: + pos_embed: [n,grid_size*grid_size, embed_dim] or [n,1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + # res = torch.FloatTensor(res).to(device) + res = res.to(device) + grid_h = torch.arange(grid_size, dtype=torch.float32, device=device) + grid_w = torch.arange(grid_size, dtype=torch.float32, device=device) + grid = torch.meshgrid( + grid_w, grid_h, indexing="xy" + ) # here h goes first,direction reversed for numpy + grid = torch.stack(grid, dim=0) # 2 x h x w + + # grid = grid.reshape([2, 1, grid_size, grid_size]) + grid = torch.einsum("chw,n->cnhw", grid, res) # 2 x n x h x w + _, n, h, w = grid.shape + pos_embed = get_2d_sincos_pos_embed_from_grid_torch( + embed_dim, grid + ) # # (nxH*W, D/2) + pos_embed = pos_embed.reshape(n, h * w, embed_dim) + if cls_token: + pos_embed = torch.cat( + [ + torch.zeros( + [n, 1, embed_dim], dtype=torch.float32, device=pos_embed.device + ), + pos_embed, + ], + dim=1, + ) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid_torch( + embed_dim // 2, grid[0] + ) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid_torch( + embed_dim // 2, grid[1] + ) # (H*W, D/2) + + emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos): + """embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + old_shape = pos + omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = torch.sin(out) # (M, D/2) + emb_cos = torch.cos(out) # (M, D/2) + + emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) + return emb From 5b94d895df00a4be0991c0b0812b19b70ae6d082 Mon Sep 17 00:00:00 2001 From: hgherzog Date: Thu, 25 Sep 2025 19:59:46 +0000 Subject: [PATCH 04/14] update path --- rslearn/models/copernicusfm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rslearn/models/copernicusfm.py b/rslearn/models/copernicusfm.py index ac313fed..dce2a8de 100644 --- a/rslearn/models/copernicusfm.py +++ b/rslearn/models/copernicusfm.py @@ -9,7 +9,7 @@ from einops import rearrange from upath import UPath -from .copernicusfm.src.model_vit import vit_base_patch16 +from .copernicusfm.model_vit import vit_base_patch16 logger = logging.getLogger(__name__) From 6a5edb9ea46b8f27728de35e0d165d4955b6d278 Mon Sep 17 00:00:00 2001 From: hgherzog Date: Thu, 25 Sep 2025 20:10:13 +0000 Subject: [PATCH 05/14] rename for name space conflict --- rslearn/models/{copernicusfm => copernicusfm_src}/__init__.py | 0 rslearn/models/{copernicusfm => copernicusfm_src}/aurora/area.py | 0 .../models/{copernicusfm => copernicusfm_src}/aurora/fourier.py | 0 .../{copernicusfm => copernicusfm_src}/dynamic_hypernetwork.py | 0 .../{copernicusfm => copernicusfm_src}/flexivit/patch_embed.py | 0 .../models/{copernicusfm => copernicusfm_src}/flexivit/utils.py | 0 rslearn/models/{copernicusfm => copernicusfm_src}/model_vit.py | 0 .../models/{copernicusfm => copernicusfm_src}/util/lr_sched.py | 0 rslearn/models/{copernicusfm => copernicusfm_src}/util/misc.py | 0 .../models/{copernicusfm => copernicusfm_src}/util/pos_embed.py | 0 10 files changed, 0 insertions(+), 0 deletions(-) rename rslearn/models/{copernicusfm => copernicusfm_src}/__init__.py (100%) rename rslearn/models/{copernicusfm => copernicusfm_src}/aurora/area.py (100%) rename rslearn/models/{copernicusfm => copernicusfm_src}/aurora/fourier.py (100%) rename rslearn/models/{copernicusfm => copernicusfm_src}/dynamic_hypernetwork.py (100%) rename rslearn/models/{copernicusfm => copernicusfm_src}/flexivit/patch_embed.py (100%) rename rslearn/models/{copernicusfm => copernicusfm_src}/flexivit/utils.py (100%) rename rslearn/models/{copernicusfm => copernicusfm_src}/model_vit.py (100%) rename rslearn/models/{copernicusfm => copernicusfm_src}/util/lr_sched.py (100%) rename rslearn/models/{copernicusfm => copernicusfm_src}/util/misc.py (100%) rename rslearn/models/{copernicusfm => copernicusfm_src}/util/pos_embed.py (100%) diff --git a/rslearn/models/copernicusfm/__init__.py b/rslearn/models/copernicusfm_src/__init__.py similarity index 100% rename from rslearn/models/copernicusfm/__init__.py rename to rslearn/models/copernicusfm_src/__init__.py diff --git a/rslearn/models/copernicusfm/aurora/area.py b/rslearn/models/copernicusfm_src/aurora/area.py similarity index 100% rename from rslearn/models/copernicusfm/aurora/area.py rename to rslearn/models/copernicusfm_src/aurora/area.py diff --git a/rslearn/models/copernicusfm/aurora/fourier.py b/rslearn/models/copernicusfm_src/aurora/fourier.py similarity index 100% rename from rslearn/models/copernicusfm/aurora/fourier.py rename to rslearn/models/copernicusfm_src/aurora/fourier.py diff --git a/rslearn/models/copernicusfm/dynamic_hypernetwork.py b/rslearn/models/copernicusfm_src/dynamic_hypernetwork.py similarity index 100% rename from rslearn/models/copernicusfm/dynamic_hypernetwork.py rename to rslearn/models/copernicusfm_src/dynamic_hypernetwork.py diff --git a/rslearn/models/copernicusfm/flexivit/patch_embed.py b/rslearn/models/copernicusfm_src/flexivit/patch_embed.py similarity index 100% rename from rslearn/models/copernicusfm/flexivit/patch_embed.py rename to rslearn/models/copernicusfm_src/flexivit/patch_embed.py diff --git a/rslearn/models/copernicusfm/flexivit/utils.py b/rslearn/models/copernicusfm_src/flexivit/utils.py similarity index 100% rename from rslearn/models/copernicusfm/flexivit/utils.py rename to rslearn/models/copernicusfm_src/flexivit/utils.py diff --git a/rslearn/models/copernicusfm/model_vit.py b/rslearn/models/copernicusfm_src/model_vit.py similarity index 100% rename from rslearn/models/copernicusfm/model_vit.py rename to rslearn/models/copernicusfm_src/model_vit.py diff --git a/rslearn/models/copernicusfm/util/lr_sched.py b/rslearn/models/copernicusfm_src/util/lr_sched.py similarity index 100% rename from rslearn/models/copernicusfm/util/lr_sched.py rename to rslearn/models/copernicusfm_src/util/lr_sched.py diff --git a/rslearn/models/copernicusfm/util/misc.py b/rslearn/models/copernicusfm_src/util/misc.py similarity index 100% rename from rslearn/models/copernicusfm/util/misc.py rename to rslearn/models/copernicusfm_src/util/misc.py diff --git a/rslearn/models/copernicusfm/util/pos_embed.py b/rslearn/models/copernicusfm_src/util/pos_embed.py similarity index 100% rename from rslearn/models/copernicusfm/util/pos_embed.py rename to rslearn/models/copernicusfm_src/util/pos_embed.py From 03ba136f334bb7248a0288ade9af78ffc8229fd2 Mon Sep 17 00:00:00 2001 From: hgherzog Date: Thu, 25 Sep 2025 20:23:26 +0000 Subject: [PATCH 06/14] add test --- rslearn/models/copernicusfm.py | 2 +- tests/unit/models/test_copernicusfm.py | 34 ++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) create mode 100644 tests/unit/models/test_copernicusfm.py diff --git a/rslearn/models/copernicusfm.py b/rslearn/models/copernicusfm.py index dce2a8de..ef7829bd 100644 --- a/rslearn/models/copernicusfm.py +++ b/rslearn/models/copernicusfm.py @@ -9,7 +9,7 @@ from einops import rearrange from upath import UPath -from .copernicusfm.model_vit import vit_base_patch16 +from .copernicusfm_src.model_vit import vit_base_patch16 logger = logging.getLogger(__name__) diff --git a/tests/unit/models/test_copernicusfm.py b/tests/unit/models/test_copernicusfm.py new file mode 100644 index 00000000..3ad51679 --- /dev/null +++ b/tests/unit/models/test_copernicusfm.py @@ -0,0 +1,34 @@ +"""Test Copernicus FM model.""" + +import pathlib +import tempfile +from typing import Any + +import torch + +from rslearn.models.copernicusfm import CopernicusFM, CopernicusFMModality +import pytest + +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +@pytest.skipif(not torch.cuda.is_available(), reason="Requires a GPU") +def test_copernicusfm() -> None: + """Verify that the forward pass for CROMA works.""" + input_hw = 32 + # We override the temporary directory so we don't retain the model weights outside + # of this test. + + band_order = { + CopernicusFMModality.SENTINEL2_L2A.value: ["B02", "B03", "B04", "B05", "B06", "B07", "B08", "B8A", "B11", "B12"], + CopernicusFMModality.SENTINEL1.value: ["vv", "vh"], + } + inputs = [ + { + "sentinel2": torch.zeros((len(band_order[CopernicusFMModality.SENTINEL2_L2A.value]), input_hw, input_hw), dtype=torch.float32, device=DEVICE), + "sentinel1": torch.zeros((len(band_order[CopernicusFMModality.SENTINEL1.value]), input_hw, input_hw), dtype=torch.float32, device=DEVICE), + } + ] + with torch.no_grad(): + copernicusfm = CopernicusFM(band_order=band_order).to(DEVICE) + feature_list = copernicusfm(inputs) + assert feature_list[0].shape == torch.Size([1, 768, 14, 14]) and len(feature_list) == 1 \ No newline at end of file From c261d5fca5f55ba195a5dde40da925013e1aec29 Mon Sep 17 00:00:00 2001 From: hgherzog Date: Thu, 25 Sep 2025 17:13:20 -0400 Subject: [PATCH 07/14] linting --- .pre-commit-config.yaml | 24 +++++---- rslearn/models/copernicusfm.py | 54 +++++++++++-------- rslearn/models/copernicusfm_src/__init__.py | 2 +- .../copernicusfm_src/dynamic_hypernetwork.py | 2 +- rslearn/models/copernicusfm_src/model_vit.py | 2 +- rslearn/models/unet.py | 8 ++- tests/unit/models/test_copernicusfm.py | 44 +++++++++++---- 7 files changed, 90 insertions(+), 46 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0773e6aa..e7c1083b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -45,12 +45,14 @@ repos: additional_dependencies: - "pydantic>=2.7.1,<3" - "types-protobuf" - exclude: ./.*_pb2_.*.py + exclude: | + (?x)^(?:.*_pb2_.*\.py|rslearn/models/copernicusfm_src/.*)$ + - repo: https://github.com/PyCQA/bandit rev: "1.8.6" hooks: - id: bandit - exclude: ^tests/ + exclude: ^tests/|^rslearn/models/copernicusfm_src/ args: - -s # Skip B113 request_without_timeout because it has false positives e.g. when @@ -64,17 +66,19 @@ repos: entry: interrogate types: [python] args: - [ - --ignore-init-method, - --ignore-init-module, - -p, - -vv, - rslearn, - --fail-under=80, - ] + - --ignore-init-method + - --ignore-init-module + - --exclude=rslearn/models/copernicusfm_src/ + - -p + - -vv + - rslearn + - --fail-under=80 + - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.12.9 hooks: - id: ruff-check args: [ --fix ] + exclude: ^rslearn/models/copernicusfm_src/ - id: ruff-format + exclude: ^rslearn/models/copernicusfm_src/ diff --git a/rslearn/models/copernicusfm.py b/rslearn/models/copernicusfm.py index ef7829bd..283afe4f 100644 --- a/rslearn/models/copernicusfm.py +++ b/rslearn/models/copernicusfm.py @@ -1,8 +1,8 @@ """Copernicus FM model.""" -from enum import Enum import logging import math +from enum import Enum import torch import torch.nn.functional as F @@ -11,13 +11,16 @@ from .copernicusfm_src.model_vit import vit_base_patch16 - logger = logging.getLogger(__name__) + + class CopernicusFMModality(Enum): """Modality for Copernicus FM.""" + SENTINEL2_L2A = "sentinel2_l2a" SENTINEL1 = "sentinel1" + MODALITY_TO_WAVELENGTH_BANDWIDTHS: dict[str, dict[str, list]] = { # https://github.com/zhu-xlab/Copernicus-FM/blob/main/Copernicus-Bench/src/configs/dataset/cobench_eurosat_s2.yaml CopernicusFMModality.SENTINEL2_L2A.value: { @@ -68,14 +71,21 @@ class CopernicusFM(torch.nn.Module): image_resolution = 224 patch_size = 16 input_mode = "spectral" - supported_modalities = [CopernicusFMModality.SENTINEL2_L2A.value, CopernicusFMModality.SENTINEL1.value] + # Don't need this as band order is provided + supported_modalities = [ + CopernicusFMModality.SENTINEL2_L2A.value, + CopernicusFMModality.SENTINEL1.value, + ] def __init__( - self, band_order: dict[str, list[str]], load_directory: str = "/weka/dfive-default/helios/models/copernicusfm" + self, + band_order: dict[str, list[str]], + load_directory: str = "/weka/dfive-default/helios/models/copernicusfm", ) -> None: """Initialize the Copernicus FM wrapper. Args: + band_order: The band order for each modality load_directory: The directory to load from """ super().__init__() @@ -84,7 +94,7 @@ def __init__( self.band_order = band_order self.model = vit_base_patch16(num_classes=10, global_pool=True) check_point = torch.load( - UPath(load_directory) / "CopernicusFM_ViT_base_varlang_e100.pth" + UPath(load_directory) / "CopernicusFM_ViT_base_varlang_e100.pth" # nosec B614 ) if "model" in check_point: state_dict = check_point["model"] @@ -97,14 +107,15 @@ def __init__( self.modality_to_wavelength_bandwidths = {} for modality in self.supported_modalities: wavelength_bandwidths = MODALITY_TO_WAVELENGTH_BANDWIDTHS[modality] - band_order = MODALITY_TO_WAVELENGTH_BANDWIDTHS[modality]["band_names"] wavelengths = [] bandwidths = [] - band_order = self.band_order.get(modality, None) - if band_order is None: - logger.warning(f"Band order for modality {modality} not found in band_order dictionary, unable to use this modality unless specified") + modality_band_order = self.band_order.get(modality, None) + if modality_band_order is None: + logger.warning( + f"Band order for modality {modality} not found in band_order dictionary, unable to use this modality unless specified" + ) continue - for b in band_order: + for b in modality_band_order: cfm_idx = wavelength_bandwidths["band_names"].index(b) wavelengths.append(wavelength_bandwidths["band_wavelengths"][cfm_idx]) bandwidths.append(wavelength_bandwidths["band_bandwidths"][cfm_idx]) @@ -113,9 +124,7 @@ def __init__( "band_wavelengths": wavelengths, } - def _resize_data( - self, data: torch.Tensor - ) -> list[torch.Tensor]: + def _resize_data(self, data: torch.Tensor) -> list[torch.Tensor]: """Process individual modality data. Args: @@ -126,9 +135,7 @@ def _resize_data( """ # Get original dimensions original_height = data.shape[2] - new_height = ( - self.patch_size if original_height == 1 else self.image_resolution - ) + new_height = self.patch_size if original_height == 1 else self.image_resolution data = F.interpolate( data, size=(new_height, new_height), @@ -140,14 +147,14 @@ def _resize_data( def prepare_input( self, inputs: dict[str, torch.Tensor], - ) -> tuple[list[torch.Tensor], list[int], list[int]]: + ) -> tuple[torch.Tensor, list[int], list[int]]: """Prepare input for the CopernicusFM model from MaskedHeliosSample.""" wavelengths: list[int] = [] bandwidths: list[int] = [] all_processed_data: list[list[torch.Tensor]] = [] for modality in inputs.keys(): if modality not in self.supported_modalities: - logger.warning( + logger.debug( f"Skipping modality {modality} as it is not in the supported " f"modalities list {self.supported_modalities}" ) @@ -174,11 +181,12 @@ def forward( inputs: list[dict[str, torch.Tensor]], ) -> torch.Tensor: """Forward pass through CopernicusFM model.""" - batch_inputs = {key: torch.stack([inp[key] for inp in inputs], dim=0) for key in inputs[0].keys()} + batch_inputs = { + key: torch.stack([inp[key] for inp in inputs], dim=0) + for key in inputs[0].keys() + } # Prepare input - data, wavelengths, bandwidths = self.prepare_input( - batch_inputs - ) + data, wavelengths, bandwidths = self.prepare_input(batch_inputs) meta = torch.full( (1, 4), float("nan"), device=data.device ) # [lon, lat, delta_time, patch_token_area], assume unknown @@ -191,7 +199,7 @@ def forward( None, self.input_mode, self.patch_size, - ) + ) # no norm, following # https://github.com/zhu-xlab/Copernicus-FM/blob/main/Copernicus-Bench/src/foundation_models/CopernicusFM/models_dwv_seg.py side = math.isqrt(timestep_output.shape[1]) diff --git a/rslearn/models/copernicusfm_src/__init__.py b/rslearn/models/copernicusfm_src/__init__.py index b97b4284..1e3572cf 100644 --- a/rslearn/models/copernicusfm_src/__init__.py +++ b/rslearn/models/copernicusfm_src/__init__.py @@ -1 +1 @@ -# type: ignore +# mypy: ignore-errors diff --git a/rslearn/models/copernicusfm_src/dynamic_hypernetwork.py b/rslearn/models/copernicusfm_src/dynamic_hypernetwork.py index f2d5b121..694dc571 100644 --- a/rslearn/models/copernicusfm_src/dynamic_hypernetwork.py +++ b/rslearn/models/copernicusfm_src/dynamic_hypernetwork.py @@ -1,4 +1,4 @@ -# type: ignore +# mypy: ignore-errors import numpy as np import torch import torch.nn as nn diff --git a/rslearn/models/copernicusfm_src/model_vit.py b/rslearn/models/copernicusfm_src/model_vit.py index 09d519c1..1f3fa73c 100644 --- a/rslearn/models/copernicusfm_src/model_vit.py +++ b/rslearn/models/copernicusfm_src/model_vit.py @@ -1,4 +1,4 @@ -# type: ignore +# mypy: ignore-errors # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. diff --git a/rslearn/models/unet.py b/rslearn/models/unet.py index 6d06d11f..4bf500f6 100644 --- a/rslearn/models/unet.py +++ b/rslearn/models/unet.py @@ -3,6 +3,7 @@ from typing import Any import torch +import torch.nn.functional as F class UNetDecoder(torch.nn.Module): @@ -129,7 +130,12 @@ def __init__( def _resize(self, features: torch.Tensor) -> torch.Tensor: """Interpolate the features to the original size.""" - return F.interpolate(features, size=self.original_size_to_interpolate, mode="bilinear", align_corners=False) + return F.interpolate( + features, + size=self.original_size_to_interpolate, + mode="bilinear", + align_corners=False, + ) def forward( self, in_features: list[torch.Tensor], inputs: list[dict[str, Any]] diff --git a/tests/unit/models/test_copernicusfm.py b/tests/unit/models/test_copernicusfm.py index 3ad51679..fb98e174 100644 --- a/tests/unit/models/test_copernicusfm.py +++ b/tests/unit/models/test_copernicusfm.py @@ -1,16 +1,13 @@ """Test Copernicus FM model.""" -import pathlib -import tempfile -from typing import Any - +import pytest import torch from rslearn.models.copernicusfm import CopernicusFM, CopernicusFMModality -import pytest DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + @pytest.skipif(not torch.cuda.is_available(), reason="Requires a GPU") def test_copernicusfm() -> None: """Verify that the forward pass for CROMA works.""" @@ -19,16 +16,45 @@ def test_copernicusfm() -> None: # of this test. band_order = { - CopernicusFMModality.SENTINEL2_L2A.value: ["B02", "B03", "B04", "B05", "B06", "B07", "B08", "B8A", "B11", "B12"], + CopernicusFMModality.SENTINEL2_L2A.value: [ + "B02", + "B03", + "B04", + "B05", + "B06", + "B07", + "B08", + "B8A", + "B11", + "B12", + ], CopernicusFMModality.SENTINEL1.value: ["vv", "vh"], } inputs = [ { - "sentinel2": torch.zeros((len(band_order[CopernicusFMModality.SENTINEL2_L2A.value]), input_hw, input_hw), dtype=torch.float32, device=DEVICE), - "sentinel1": torch.zeros((len(band_order[CopernicusFMModality.SENTINEL1.value]), input_hw, input_hw), dtype=torch.float32, device=DEVICE), + "sentinel2": torch.zeros( + ( + len(band_order[CopernicusFMModality.SENTINEL2_L2A.value]), + input_hw, + input_hw, + ), + dtype=torch.float32, + device=DEVICE, + ), + "sentinel1": torch.zeros( + ( + len(band_order[CopernicusFMModality.SENTINEL1.value]), + input_hw, + input_hw, + ), + dtype=torch.float32, + device=DEVICE, + ), } ] with torch.no_grad(): copernicusfm = CopernicusFM(band_order=band_order).to(DEVICE) feature_list = copernicusfm(inputs) - assert feature_list[0].shape == torch.Size([1, 768, 14, 14]) and len(feature_list) == 1 \ No newline at end of file + assert ( + feature_list[0].shape == torch.Size([1, 768, 14, 14]) and len(feature_list) == 1 + ) From 8d32f12c1f907b344b32eb18a27a0393f96c588d Mon Sep 17 00:00:00 2001 From: hgherzog Date: Thu, 25 Sep 2025 21:28:25 +0000 Subject: [PATCH 08/14] fix test --- tests/unit/models/test_copernicusfm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/models/test_copernicusfm.py b/tests/unit/models/test_copernicusfm.py index fb98e174..1419438c 100644 --- a/tests/unit/models/test_copernicusfm.py +++ b/tests/unit/models/test_copernicusfm.py @@ -8,7 +8,7 @@ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") -@pytest.skipif(not torch.cuda.is_available(), reason="Requires a GPU") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires a GPU") def test_copernicusfm() -> None: """Verify that the forward pass for CROMA works.""" input_hw = 32 From ba8b38e9f913f56008be6479ccc58ccc2636e69c Mon Sep 17 00:00:00 2001 From: hgherzog Date: Fri, 26 Sep 2025 07:48:10 -0400 Subject: [PATCH 09/14] adress comments --- rslearn/models/copernicusfm.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/rslearn/models/copernicusfm.py b/rslearn/models/copernicusfm.py index 283afe4f..65228387 100644 --- a/rslearn/models/copernicusfm.py +++ b/rslearn/models/copernicusfm.py @@ -80,7 +80,7 @@ class CopernicusFM(torch.nn.Module): def __init__( self, band_order: dict[str, list[str]], - load_directory: str = "/weka/dfive-default/helios/models/copernicusfm", + load_directory: str, ) -> None: """Initialize the Copernicus FM wrapper. @@ -94,7 +94,8 @@ def __init__( self.band_order = band_order self.model = vit_base_patch16(num_classes=10, global_pool=True) check_point = torch.load( - UPath(load_directory) / "CopernicusFM_ViT_base_varlang_e100.pth" # nosec B614 + UPath(load_directory) / "CopernicusFM_ViT_base_varlang_e100.pth", + weights_only=True, ) if "model" in check_point: state_dict = check_point["model"] From 7cd55da8ca20be33566b2d5685db72a9cef6171f Mon Sep 17 00:00:00 2001 From: hgherzog Date: Fri, 26 Sep 2025 07:50:47 -0400 Subject: [PATCH 10/14] start cleaning up modules --- .../copernicusfm_src/dynamic_hypernetwork.py | 12 +- rslearn/models/copernicusfm_src/util/misc.py | 404 ------------------ 2 files changed, 3 insertions(+), 413 deletions(-) delete mode 100644 rslearn/models/copernicusfm_src/util/misc.py diff --git a/rslearn/models/copernicusfm_src/dynamic_hypernetwork.py b/rslearn/models/copernicusfm_src/dynamic_hypernetwork.py index 694dc571..e8df04b9 100644 --- a/rslearn/models/copernicusfm_src/dynamic_hypernetwork.py +++ b/rslearn/models/copernicusfm_src/dynamic_hypernetwork.py @@ -12,12 +12,6 @@ from .flexivit.patch_embed import pi_resize_patch_embed from .util.pos_embed import get_1d_sincos_pos_embed_from_grid_torch -# from torchvision.datasets.utils import download_url - - -random_seed = 1234 -torch.manual_seed(random_seed) - class TransformerWeightGenerator(nn.Module): def __init__(self, input_dim, output_dim, embed_dim, num_heads=4, num_layers=1): @@ -82,9 +76,9 @@ def forward(self, x): batches, channels, width, height = x.shape - assert channels == self._num_input_channels, ( - f"Expected input to have {self._num_input_channels} channels (got {channels} channels)" - ) + assert ( + channels == self._num_input_channels + ), f"Expected input to have {self._num_input_channels} channels (got {channels} channels)" # Make shape compatible for matmul with _B. # From [B, C, W, H] to [(B*W*H), C]. diff --git a/rslearn/models/copernicusfm_src/util/misc.py b/rslearn/models/copernicusfm_src/util/misc.py deleted file mode 100644 index a2f59abb..00000000 --- a/rslearn/models/copernicusfm_src/util/misc.py +++ /dev/null @@ -1,404 +0,0 @@ -# type: ignore - -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -# -------------------------------------------------------- -# References: -# DeiT: https://github.com/facebookresearch/deit -# BEiT: https://github.com/microsoft/unilm/tree/master/beit -# -------------------------------------------------------- - -import builtins -import datetime -import os -import time -from collections import defaultdict, deque -from pathlib import Path - -import torch -import torch.distributed as dist - -# from torch._six import inf -from torch import inf - - -class SmoothedValue: - """Track a series of values and provide access to smoothed values over a - window or the global series average. - """ - - def __init__(self, window_size=20, fmt=None): - if fmt is None: - fmt = "{median:.4f} ({global_avg:.4f})" - self.deque = deque(maxlen=window_size) - self.total = 0.0 - self.count = 0 - self.fmt = fmt - - def update(self, value, n=1): - self.deque.append(value) - self.count += n - self.total += value * n - - def synchronize_between_processes(self): - """Warning: does not synchronize the deque!""" - if not is_dist_avail_and_initialized(): - return - t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") - dist.barrier() - dist.all_reduce(t) - t = t.tolist() - self.count = int(t[0]) - self.total = t[1] - - @property - def median(self): - d = torch.tensor(list(self.deque)) - return d.median().item() - - @property - def avg(self): - d = torch.tensor(list(self.deque), dtype=torch.float32) - return d.mean().item() - - @property - def global_avg(self): - return self.total / self.count - - @property - def max(self): - return max(self.deque) - - @property - def value(self): - return self.deque[-1] - - def __str__(self): - return self.fmt.format( - median=self.median, - avg=self.avg, - global_avg=self.global_avg, - max=self.max, - value=self.value, - ) - - -class MetricLogger: - def __init__(self, delimiter="\t"): - self.meters = defaultdict(SmoothedValue) - self.delimiter = delimiter - - def update(self, **kwargs): - for k, v in kwargs.items(): - if v is None: - continue - if isinstance(v, torch.Tensor): - v = v.item() - assert isinstance(v, (float, int)) - self.meters[k].update(v) - - def __getattr__(self, attr): - if attr in self.meters: - return self.meters[attr] - if attr in self.__dict__: - return self.__dict__[attr] - raise AttributeError( - f"'{type(self).__name__}' object has no attribute '{attr}'" - ) - - def __str__(self): - loss_str = [] - for name, meter in self.meters.items(): - loss_str.append(f"{name}: {str(meter)}") - return self.delimiter.join(loss_str) - - def synchronize_between_processes(self): - for meter in self.meters.values(): - meter.synchronize_between_processes() - - def add_meter(self, name, meter): - self.meters[name] = meter - - def log_every(self, iterable, print_freq, header=None): - i = 0 - if not header: - header = "" - start_time = time.time() - end = time.time() - iter_time = SmoothedValue(fmt="{avg:.4f}") - data_time = SmoothedValue(fmt="{avg:.4f}") - space_fmt = ":" + str(len(str(len(iterable)))) + "d" - log_msg = [ - header, - "[{0" + space_fmt + "}/{1}]", - "eta: {eta}", - "{meters}", - "time: {time}", - "data: {data}", - ] - if torch.cuda.is_available(): - log_msg.append("max mem: {memory:.0f}") - log_msg = self.delimiter.join(log_msg) - MB = 1024.0 * 1024.0 - for obj in iterable: - data_time.update(time.time() - end) - yield obj - iter_time.update(time.time() - end) - if i % print_freq == 0 or i == len(iterable) - 1: - eta_seconds = iter_time.global_avg * (len(iterable) - i) - eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) - if torch.cuda.is_available(): - print( - log_msg.format( - i, - len(iterable), - eta=eta_string, - meters=str(self), - time=str(iter_time), - data=str(data_time), - memory=torch.cuda.max_memory_allocated() / MB, - ) - ) - else: - print( - log_msg.format( - i, - len(iterable), - eta=eta_string, - meters=str(self), - time=str(iter_time), - data=str(data_time), - ) - ) - i += 1 - end = time.time() - total_time = time.time() - start_time - total_time_str = str(datetime.timedelta(seconds=int(total_time))) - print( - f"{header} Total time: {total_time_str} ({total_time / len(iterable):.4f} s / it)" - ) - - -def setup_for_distributed(is_master): - """This function disables printing when not in master process""" - builtin_print = builtins.print - - def print(*args, **kwargs): - force = kwargs.pop("force", False) - force = force or (get_world_size() > 16) - if is_master or force: - now = datetime.datetime.now().time() - builtin_print(f"[{now}] ", end="") # print with time stamp - builtin_print(*args, **kwargs) - - builtins.print = print - - -def is_dist_avail_and_initialized(): - if not dist.is_available(): - return False - if not dist.is_initialized(): - return False - return True - - -def get_world_size(): - if not is_dist_avail_and_initialized(): - return 1 - return dist.get_world_size() - - -def get_rank(): - if not is_dist_avail_and_initialized(): - return 0 - return dist.get_rank() - - -def is_main_process(): - return get_rank() == 0 - - -def save_on_master(*args, **kwargs): - if is_main_process(): - torch.save(*args, **kwargs) - - -def init_distributed_mode(args): - if args.dist_on_itp: - args.rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) - args.world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"]) - args.gpu = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) - args.dist_url = "tcp://%s:%s" % ( - os.environ["MASTER_ADDR"], - os.environ["MASTER_PORT"], - ) - os.environ["LOCAL_RANK"] = str(args.gpu) - os.environ["RANK"] = str(args.rank) - os.environ["WORLD_SIZE"] = str(args.world_size) - # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] - elif "RANK" in os.environ and "WORLD_SIZE" in os.environ: - args.rank = int(os.environ["RANK"]) - args.world_size = int(os.environ["WORLD_SIZE"]) - args.gpu = int(os.environ["LOCAL_RANK"]) - elif "SLURM_PROCID" in os.environ: - args.rank = int(os.environ["SLURM_PROCID"]) - if "SLURM_LOCALID" in os.environ: - local_rank = int(os.environ["SLURM_LOCALID"]) - # If only one GPU is visible, force to 0 - if torch.cuda.device_count() == 1: - args.gpu = 0 - else: - args.gpu = local_rank - else: - args.gpu = args.rank % torch.cuda.device_count() - args.world_size = int(os.environ["SLURM_NNODES"]) * int( - os.environ["SLURM_TASKS_PER_NODE"][0] - ) - else: - print("Not using distributed mode") - setup_for_distributed(is_master=True) # hack - args.distributed = False - return - - args.distributed = True - - print("procid", int(os.environ.get("SLURM_PROCID"))) - print("localid", int(os.environ.get("SLURM_LOCALID"))) - print("devicecount", torch.cuda.device_count()) - print("CUDA_VISIBLE_DEVICES:", os.environ.get("CUDA_VISIBLE_DEVICES")) - - torch.cuda.set_device(args.gpu) - args.dist_backend = "nccl" - print( - f"World Size {args.world_size} | distributed init (rank {args.rank}): {args.dist_url}, gpu {args.gpu}", - flush=True, - ) - torch.distributed.init_process_group( - backend=args.dist_backend, - init_method=args.dist_url, - world_size=args.world_size, - rank=args.rank, - ) - print("distributed init done") - torch.distributed.barrier() - setup_for_distributed(args.rank == 0) - print("setup print done") - - -class NativeScalerWithGradNormCount: - state_dict_key = "amp_scaler" - - def __init__(self): - self._scaler = torch.cuda.amp.GradScaler() - - def __call__( - self, - loss, - optimizer, - clip_grad=None, - parameters=None, - create_graph=False, - update_grad=True, - ): - self._scaler.scale(loss).backward(create_graph=create_graph) - if update_grad: - if clip_grad is not None: - assert parameters is not None - self._scaler.unscale_( - optimizer - ) # unscale the gradients of optimizer's assigned params in-place - norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) - else: - self._scaler.unscale_(optimizer) - norm = get_grad_norm_(parameters) - self._scaler.step(optimizer) - self._scaler.update() - else: - norm = None - return norm - - def state_dict(self): - return self._scaler.state_dict() - - def load_state_dict(self, state_dict): - self._scaler.load_state_dict(state_dict) - - -def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: - if isinstance(parameters, torch.Tensor): - parameters = [parameters] - parameters = [p for p in parameters if p.grad is not None] - norm_type = float(norm_type) - if len(parameters) == 0: - return torch.tensor(0.0) - device = parameters[0].grad.device - if norm_type == inf: - total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) - else: - total_norm = torch.norm( - torch.stack( - [torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters] - ), - norm_type, - ) - return total_norm - - -def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler): - output_dir = Path(args.output_dir) - epoch_name = str(epoch) - if loss_scaler is not None: - checkpoint_paths = [output_dir / ("checkpoint-%s.pth" % epoch_name)] - for checkpoint_path in checkpoint_paths: - to_save = { - "model": model_without_ddp.state_dict(), - "optimizer": optimizer.state_dict(), - "epoch": epoch, - "scaler": loss_scaler.state_dict(), - "args": args, - } - - save_on_master(to_save, checkpoint_path) - else: - client_state = {"epoch": epoch} - model.save_checkpoint( - save_dir=args.output_dir, - tag="checkpoint-%s" % epoch_name, - client_state=client_state, - ) - - -def load_model(args, model_without_ddp, optimizer, loss_scaler): - if args.resume: - if args.resume.startswith("https"): - checkpoint = torch.hub.load_state_dict_from_url( - args.resume, map_location="cpu", check_hash=True - ) - else: - checkpoint = torch.load(args.resume, map_location="cpu") - model_without_ddp.load_state_dict(checkpoint["model"]) - print("Resume checkpoint %s" % args.resume) - if ( - "optimizer" in checkpoint - and "epoch" in checkpoint - and not (hasattr(args, "eval") and args.eval) - ): - optimizer.load_state_dict(checkpoint["optimizer"]) - args.start_epoch = checkpoint["epoch"] + 1 - if "scaler" in checkpoint: - loss_scaler.load_state_dict(checkpoint["scaler"]) - print("With optim & sched!") - - -def all_reduce_mean(x): - world_size = get_world_size() - if world_size > 1: - x_reduce = torch.tensor(x).cuda() - dist.all_reduce(x_reduce) - x_reduce /= world_size - return x_reduce.item() - else: - return x From d8e1fa3a8acd02c5646362800e387fd4db46b299 Mon Sep 17 00:00:00 2001 From: hgherzog Date: Fri, 26 Sep 2025 19:45:59 -0400 Subject: [PATCH 11/14] remove functions --- .../models/copernicusfm_src/aurora/area.py | 98 +------------------ 1 file changed, 1 insertion(+), 97 deletions(-) diff --git a/rslearn/models/copernicusfm_src/aurora/area.py b/rslearn/models/copernicusfm_src/aurora/area.py index d980e66c..47599b2e 100644 --- a/rslearn/models/copernicusfm_src/aurora/area.py +++ b/rslearn/models/copernicusfm_src/aurora/area.py @@ -2,7 +2,7 @@ import torch -__all__ = ["area", "compute_patch_areas", "radius_earth"] +__all__ = ["area", "radius_earth"] # float: Radius of the earth in kilometers. @@ -48,99 +48,3 @@ def area(polygon: torch.Tensor) -> torch.Tensor: area = area * radius_earth * radius_earth / 2 return torch.abs(area) - - -def expand_matrix(matrix: torch.Tensor) -> torch.Tensor: - """Expand matrix by adding one row and one column to each side, using - linear interpolation. - - Args: - matrix (:class:`torch.Tensor`): Matrix to expand. - - Returns: - :class:`torch.Tensor`: `matrix`, but with two extra rows and two extra columns. - """ - # Add top and bottom rows. - matrix = torch.cat( - ( - 2 * matrix[0:1] - matrix[1:2], - matrix, - 2 * matrix[-1:] - matrix[-2:-1], - ), - dim=0, - ) - - # Add left and right columns. - matrix = torch.cat( - ( - 2 * matrix[:, 0:1] - matrix[:, 1:2], - matrix, - 2 * matrix[:, -1:] - matrix[:, -2:-1], - ), - dim=1, - ) - - return matrix - - -def compute_patch_areas(lat: torch.Tensor, lon: torch.Tensor) -> torch.Tensor: - """A pair of latitude and longitude matrices defines a number non-intersecting patches on the - Earth. For a global grid, these patches span the entire surface of the Earth. For a local grid, - the patches might span only a country or a continent. This function computes the area of every - specified patch. - - To divide the Earth into patches, the idea is to let a grid point be the _center_ of the - corresponding patch. The vertices of this patch will then sit exactly inbetween the grid - point and the grid points immediately diagonally and non-diagonally above, below, left, and - right. For a grid point at the very top of the grid, for example, there is no immediately above - grid point. In that case, we enlarge the grid by a row at the top by linearly interpolating the - latitudinal progression. - - Summary of algorithm: - 1. Enlarge the latitude and longitude matrices by adding one row and one column to each side. - 2. Calculate the patch vertices by averaging every 2x2 square in the enlarged grid. We also - call these points the midpoints. - 3. By using the vertices of the patches, i.e. the midpoints, compute the areas of the patches. - - Args: - lat (:class:`torch.Tensor`): Latitude matrix. Must be decreasing along rows. - lon (:class:`torch.Tensor`): Longitude matrix. Must be increasing along columns. - - Returns: - :class:`torch.Tensor`: Areas in square kilometer. - """ - if not (lat.dim() == lon.dim() == 2): - raise ValueError("`lat` and `lon` must both be matrices.") - if lat.shape != lat.shape: - raise ValueError("`lat` and `lon` must have the same shape.") - - # Check that the latitude matrix is decreasing in the appropriate way. - if not torch.all(lat[1:] - lat[:-1] <= 0): - raise ValueError("`lat` must be decreasing along rows.") - - # Check that the longitude matrix is increasing in the appropriate way. - if not torch.all(lon[:, 1:] - lon[:, :-1] >= 0): - raise ValueError("`lon` must be increasing along columns.") - - # Enlarge the latitude and longitude matrices for the midpoint computation. - lat = expand_matrix(lat) - lon = expand_matrix(lon) - - # Latitudes cannot expand beyond the poles. - lat = torch.clamp(lat, -90, 90) - - # Calculate midpoints between entries in lat/lon. This is very important for symmetry of the - # resulting areas. - lat_midpoints = (lat[:-1, :-1] + lat[:-1, 1:] + lat[1:, :-1] + lat[1:, 1:]) / 4 - lon_midpoints = (lon[:-1, :-1] + lon[:-1, 1:] + lon[1:, :-1] + lon[1:, 1:]) / 4 - - # Determine squares and return the area of those squares. - top_left = torch.stack((lat_midpoints[1:, :-1], lon_midpoints[1:, :-1]), dim=-1) - top_right = torch.stack((lat_midpoints[1:, 1:], lon_midpoints[1:, 1:]), dim=-1) - bottom_left = torch.stack( - (lat_midpoints[:-1, :-1], lon_midpoints[:-1, :-1]), dim=-1 - ) - bottom_right = torch.stack((lat_midpoints[:-1, 1:], lon_midpoints[:-1, 1:]), dim=-1) - polygon = torch.stack((top_left, top_right, bottom_right, bottom_left), dim=-2) - - return area(polygon) From e36f57244266317a158fe22f8f66cc8e32532e94 Mon Sep 17 00:00:00 2001 From: hgherzog Date: Fri, 26 Sep 2025 19:47:56 -0400 Subject: [PATCH 12/14] remove functions --- .../models/copernicusfm_src/util/lr_sched.py | 29 ------------------- 1 file changed, 29 deletions(-) delete mode 100644 rslearn/models/copernicusfm_src/util/lr_sched.py diff --git a/rslearn/models/copernicusfm_src/util/lr_sched.py b/rslearn/models/copernicusfm_src/util/lr_sched.py deleted file mode 100644 index 0044c160..00000000 --- a/rslearn/models/copernicusfm_src/util/lr_sched.py +++ /dev/null @@ -1,29 +0,0 @@ -# type: ignore -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import math - - -def adjust_learning_rate(optimizer, epoch, args): - """Decay the learning rate with half-cycle cosine after warmup""" - if epoch < args.warmup_epochs: - lr = args.lr * epoch / args.warmup_epochs - else: - lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * ( - 1.0 - + math.cos( - math.pi - * (epoch - args.warmup_epochs) - / (args.epochs - args.warmup_epochs) - ) - ) - for param_group in optimizer.param_groups: - if "lr_scale" in param_group: - param_group["lr"] = lr * param_group["lr_scale"] - else: - param_group["lr"] = lr - return lr From 492f169330663133657200b03ec1aaf3b46fa610 Mon Sep 17 00:00:00 2001 From: hgherzog Date: Fri, 26 Sep 2025 19:54:39 -0400 Subject: [PATCH 13/14] fix lints --- rslearn/models/copernicusfm.py | 23 ++++++++++++----------- tests/unit/models/test_copernicusfm.py | 4 ++-- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/rslearn/models/copernicusfm.py b/rslearn/models/copernicusfm.py index 65228387..2e19311a 100644 --- a/rslearn/models/copernicusfm.py +++ b/rslearn/models/copernicusfm.py @@ -80,28 +80,29 @@ class CopernicusFM(torch.nn.Module): def __init__( self, band_order: dict[str, list[str]], - load_directory: str, + load_directory: str | None, ) -> None: """Initialize the Copernicus FM wrapper. Args: band_order: The band order for each modality - load_directory: The directory to load from + load_directory: The directory to load from, if None no weights are loaded """ super().__init__() # global_pool=True so that we initialize the fc_norm layer self.band_order = band_order self.model = vit_base_patch16(num_classes=10, global_pool=True) - check_point = torch.load( - UPath(load_directory) / "CopernicusFM_ViT_base_varlang_e100.pth", - weights_only=True, - ) - if "model" in check_point: - state_dict = check_point["model"] - else: - state_dict = check_point - self.model.load_state_dict(state_dict, strict=False) + if load_directory is not None: + check_point = torch.load( + UPath(load_directory) / "CopernicusFM_ViT_base_varlang_e100.pth", + weights_only=True, + ) + if "model" in check_point: + state_dict = check_point["model"] + else: + state_dict = check_point + self.model.load_state_dict(state_dict, strict=False) # take MODALITY_TO_WAVELENGTH_BANDWIDTHS and rearrage it so that it has the same # ordering as the Helios band orders, defined by Modality.band_order diff --git a/tests/unit/models/test_copernicusfm.py b/tests/unit/models/test_copernicusfm.py index 1419438c..2b98565f 100644 --- a/tests/unit/models/test_copernicusfm.py +++ b/tests/unit/models/test_copernicusfm.py @@ -52,9 +52,9 @@ def test_copernicusfm() -> None: ), } ] + copernicusfm = CopernicusFM(band_order=band_order, load_directory=None).to(DEVICE) with torch.no_grad(): - copernicusfm = CopernicusFM(band_order=band_order).to(DEVICE) - feature_list = copernicusfm(inputs) + feature_list = copernicusfm(inputs) assert ( feature_list[0].shape == torch.Size([1, 768, 14, 14]) and len(feature_list) == 1 ) From f78fa85ac972f93c52625ace4a294aea65c23a62 Mon Sep 17 00:00:00 2001 From: hgherzog Date: Tue, 30 Sep 2025 08:15:45 -0700 Subject: [PATCH 14/14] fix type --- rslearn/models/copernicusfm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rslearn/models/copernicusfm.py b/rslearn/models/copernicusfm.py index 2e19311a..89b76725 100644 --- a/rslearn/models/copernicusfm.py +++ b/rslearn/models/copernicusfm.py @@ -126,7 +126,7 @@ def __init__( "band_wavelengths": wavelengths, } - def _resize_data(self, data: torch.Tensor) -> list[torch.Tensor]: + def _resize_data(self, data: torch.Tensor) -> torch.Tensor: """Process individual modality data. Args: