-
Notifications
You must be signed in to change notification settings - Fork 7
Add Copernicus FM #293
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Add Copernicus FM #293
Changes from 8 commits
Commits
Show all changes
20 commits
Select commit
Hold shift + click to select a range
fe4b7ee
allow unet to resize outputs
Hgherzog 75a0775
Add source files
Hgherzog d2b533c
add source files
Hgherzog 5b94d89
update path
Hgherzog 6a5edb9
rename for name space conflict
Hgherzog 03ba136
add test
Hgherzog c261d5f
linting
Hgherzog 8d32f12
fix test
Hgherzog 0242917
Merge branch 'master' into henryh/add-copernicus-fm
Hgherzog ba8b38e
adress comments
Hgherzog ac8871e
Merge branch 'henryh/add-copernicus-fm' of https://github.com/allenai…
Hgherzog 7cd55da
start cleaning up modules
Hgherzog 24cb398
Merge branch 'master' into henryh/add-copernicus-fm
Hgherzog d8e1fa3
remove functions
Hgherzog 2cf4fd3
Merge branch 'henryh/add-copernicus-fm' of https://github.com/allenai…
Hgherzog e36f572
remove functions
Hgherzog 492f169
fix lints
Hgherzog f78fa85
fix type
Hgherzog c919e83
Merge branch 'master' into henryh/add-copernicus-fm
Hgherzog 29c57ad
Merge branch 'master' into henryh/add-copernicus-fm
Hgherzog File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,214 @@ | ||
| """Copernicus FM model.""" | ||
|
|
||
| import logging | ||
| import math | ||
| from enum import Enum | ||
|
|
||
| import torch | ||
| import torch.nn.functional as F | ||
| from einops import rearrange | ||
| from upath import UPath | ||
|
|
||
| from .copernicusfm_src.model_vit import vit_base_patch16 | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| class CopernicusFMModality(Enum): | ||
| """Modality for Copernicus FM.""" | ||
|
|
||
| SENTINEL2_L2A = "sentinel2_l2a" | ||
| SENTINEL1 = "sentinel1" | ||
|
|
||
|
|
||
| MODALITY_TO_WAVELENGTH_BANDWIDTHS: dict[str, dict[str, list]] = { | ||
| # https://github.com/zhu-xlab/Copernicus-FM/blob/main/Copernicus-Bench/src/configs/dataset/cobench_eurosat_s2.yaml | ||
| CopernicusFMModality.SENTINEL2_L2A.value: { | ||
| "band_names": [ | ||
| "B01", | ||
| "B02", | ||
| "B03", | ||
| "B04", | ||
| "B05", | ||
| "B06", | ||
| "B07", | ||
| "B08", | ||
| "B8A", | ||
| "B09", | ||
| "B10", | ||
| "B11", | ||
| "B12", | ||
| ], | ||
| "band_wavelengths": [ | ||
| 440, | ||
| 490, | ||
| 560, | ||
| 665, | ||
| 705, | ||
| 740, | ||
| 783, | ||
| 842, | ||
| 860, | ||
| 940, | ||
| 1370, | ||
| 1610, | ||
| 2190, | ||
| ], | ||
| "band_bandwidths": [20, 65, 35, 30, 15, 15, 20, 115, 20, 20, 30, 90, 180], | ||
| }, | ||
| # https://github.com/zhu-xlab/Copernicus-FM/blob/main/Copernicus-Bench/src/configs/dataset/cobench_eurosat_s1.yaml | ||
| CopernicusFMModality.SENTINEL1.value: { | ||
| "band_names": ["vv", "vh"], | ||
| "band_wavelengths": [50000000, 50000000], | ||
| "band_bandwidths": [1e9, 1e9], | ||
| }, | ||
| } | ||
|
|
||
|
|
||
| class CopernicusFM(torch.nn.Module): | ||
| """Wrapper for Copernicus FM to ingest Masked Helios Sample.""" | ||
|
|
||
| image_resolution = 224 | ||
| patch_size = 16 | ||
| input_mode = "spectral" | ||
| # Don't need this as band order is provided | ||
| supported_modalities = [ | ||
| CopernicusFMModality.SENTINEL2_L2A.value, | ||
| CopernicusFMModality.SENTINEL1.value, | ||
| ] | ||
|
|
||
| def __init__( | ||
| self, | ||
| band_order: dict[str, list[str]], | ||
| load_directory: str = "/weka/dfive-default/helios/models/copernicusfm", | ||
Hgherzog marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| ) -> None: | ||
| """Initialize the Copernicus FM wrapper. | ||
|
|
||
| Args: | ||
| band_order: The band order for each modality | ||
| load_directory: The directory to load from | ||
| """ | ||
| super().__init__() | ||
|
|
||
| # global_pool=True so that we initialize the fc_norm layer | ||
| self.band_order = band_order | ||
| self.model = vit_base_patch16(num_classes=10, global_pool=True) | ||
| check_point = torch.load( | ||
| UPath(load_directory) / "CopernicusFM_ViT_base_varlang_e100.pth" # nosec B614 | ||
Hgherzog marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| ) | ||
| if "model" in check_point: | ||
| state_dict = check_point["model"] | ||
| else: | ||
| state_dict = check_point | ||
| self.model.load_state_dict(state_dict, strict=False) | ||
|
|
||
| # take MODALITY_TO_WAVELENGTH_BANDWIDTHS and rearrage it so that it has the same | ||
| # ordering as the Helios band orders, defined by Modality.band_order | ||
| self.modality_to_wavelength_bandwidths = {} | ||
| for modality in self.supported_modalities: | ||
| wavelength_bandwidths = MODALITY_TO_WAVELENGTH_BANDWIDTHS[modality] | ||
| wavelengths = [] | ||
| bandwidths = [] | ||
| modality_band_order = self.band_order.get(modality, None) | ||
| if modality_band_order is None: | ||
| logger.warning( | ||
| f"Band order for modality {modality} not found in band_order dictionary, unable to use this modality unless specified" | ||
| ) | ||
| continue | ||
| for b in modality_band_order: | ||
| cfm_idx = wavelength_bandwidths["band_names"].index(b) | ||
| wavelengths.append(wavelength_bandwidths["band_wavelengths"][cfm_idx]) | ||
| bandwidths.append(wavelength_bandwidths["band_bandwidths"][cfm_idx]) | ||
| self.modality_to_wavelength_bandwidths[modality] = { | ||
| "band_bandwidths": bandwidths, | ||
| "band_wavelengths": wavelengths, | ||
| } | ||
|
|
||
| def _resize_data(self, data: torch.Tensor) -> list[torch.Tensor]: | ||
Hgherzog marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| """Process individual modality data. | ||
|
|
||
| Args: | ||
| data: Input tensor of shape [B, C, H, W] | ||
|
|
||
| Returns: | ||
| list of tensors of shape [B, C, H, W] | ||
| """ | ||
| # Get original dimensions | ||
| original_height = data.shape[2] | ||
| new_height = self.patch_size if original_height == 1 else self.image_resolution | ||
| data = F.interpolate( | ||
| data, | ||
| size=(new_height, new_height), | ||
| mode="bilinear", | ||
| align_corners=False, | ||
| ) | ||
| return data | ||
|
|
||
| def prepare_input( | ||
| self, | ||
| inputs: dict[str, torch.Tensor], | ||
| ) -> tuple[torch.Tensor, list[int], list[int]]: | ||
| """Prepare input for the CopernicusFM model from MaskedHeliosSample.""" | ||
| wavelengths: list[int] = [] | ||
| bandwidths: list[int] = [] | ||
| all_processed_data: list[list[torch.Tensor]] = [] | ||
| for modality in inputs.keys(): | ||
| if modality not in self.supported_modalities: | ||
| logger.debug( | ||
| f"Skipping modality {modality} as it is not in the supported " | ||
| f"modalities list {self.supported_modalities}" | ||
| ) | ||
| continue | ||
|
|
||
| data = inputs[modality] | ||
|
|
||
| if data is None: | ||
| continue | ||
|
|
||
| all_processed_data.append(self._resize_data(data)) | ||
| wavelengths.extend( | ||
| self.modality_to_wavelength_bandwidths[modality]["band_wavelengths"] | ||
| ) | ||
| bandwidths.extend( | ||
| self.modality_to_wavelength_bandwidths[modality]["band_bandwidths"] | ||
| ) | ||
|
|
||
| concatenated_processed_data = torch.cat(all_processed_data, dim=1) | ||
| return concatenated_processed_data, wavelengths, bandwidths | ||
|
|
||
| def forward( | ||
| self, | ||
| inputs: list[dict[str, torch.Tensor]], | ||
| ) -> torch.Tensor: | ||
| """Forward pass through CopernicusFM model.""" | ||
| batch_inputs = { | ||
| key: torch.stack([inp[key] for inp in inputs], dim=0) | ||
| for key in inputs[0].keys() | ||
| } | ||
| # Prepare input | ||
| data, wavelengths, bandwidths = self.prepare_input(batch_inputs) | ||
| meta = torch.full( | ||
| (1, 4), float("nan"), device=data.device | ||
| ) # [lon, lat, delta_time, patch_token_area], assume unknown | ||
Hgherzog marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| # "The embed tensor contains the encoded image features, which can be used for downstream tasks." | ||
| _, timestep_output = self.model( | ||
| data, | ||
| meta, | ||
| wavelengths, | ||
| bandwidths, | ||
| None, | ||
| self.input_mode, | ||
| self.patch_size, | ||
| ) | ||
| # no norm, following | ||
| # https://github.com/zhu-xlab/Copernicus-FM/blob/main/Copernicus-Bench/src/foundation_models/CopernicusFM/models_dwv_seg.py | ||
| side = math.isqrt(timestep_output.shape[1]) | ||
| output_features = rearrange( | ||
| timestep_output, "b (h w) c -> b c h w ", h=side, w=side | ||
| ) | ||
| return [output_features] | ||
|
|
||
| def get_backbone_channels(self) -> list[tuple[int, int]]: | ||
| """Returns the output channels of this model when used as a backbone.""" | ||
| # TODO: load this from a constant depending on the model size | ||
| return [(self.patch_size, 768)] | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| # mypy: ignore-errors |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.