diff --git a/data/helios/v2_mozambique_lulc/finetune_s1_s2.yaml b/data/helios/v2_mozambique_lulc/finetune_s1_s2.yaml new file mode 100644 index 00000000..dfff8df9 --- /dev/null +++ b/data/helios/v2_mozambique_lulc/finetune_s1_s2.yaml @@ -0,0 +1,250 @@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + encoder: + - class_path: rslearn.models.olmoearth_pretrain.model.OlmoEarth + init_args: + checkpoint_path: /weka/dfive-default/helios/checkpoints/henryh/base_v6.1_add_chm_cdl_worldcereal/step500000 + selector: ["encoder"] + forward_kwargs: + patch_size: 1 + decoders: + crop_type_classification: + - class_path: rslp.nandi.train.SegmentationPoolingDecoder + init_args: + in_channels: 768 + out_channels: 8 + - class_path: rslearn.train.tasks.segmentation.SegmentationHead + lr: 0.0001 + scheduler: + class_path: rslearn.train.scheduler.PlateauScheduler + init_args: + factor: 0.2 + patience: 2 + min_lr: 0 + cooldown: 10 +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: /weka/dfive-default/rslearn-eai/datasets/crop/mozambique_lulc + inputs: + sentinel2_l2a: + data_type: "raster" + layers: ["sentinel2"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + load_all_item_groups: true + load_all_layers: true + sentinel1: + data_type: "raster" + layers: ["sentinel1_descending"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + load_all_item_groups: true + load_all_layers: true + label: + data_type: "raster" + layers: ["label_raster"] + bands: ["label"] + is_target: true + dtype: INT32 + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + tasks: + crop_type_classification: + class_path: rslearn.train.tasks.segmentation.SegmentationTask + init_args: + num_classes: 8 + zero_is_invalid: true + metric_kwargs: + average: "micro" + other_metrics: + water_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 1 + water_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 1 + bareground_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 2 + bareground_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 2 + rangeland_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 3 + rangeland_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 3 + floodedvegetation_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 4 + floodedvegetation_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 4 + trees_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 5 + trees_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 5 + cropland_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 6 + cropland_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 6 + buildings_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 7 + buildings_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 7 + input_mapping: + crop_type_classification: + label: "targets" + batch_size: 32 + num_workers: 32 + default_config: + transforms: + - class_path: rslearn.models.olmoearth_pretrain.norm.OlmoEarthNormalize + init_args: + band_names: + sentinel2_l2a: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + sentinel1: ["vv", "vh"] + - class_path: rslearn.train.transforms.pad.Pad + init_args: + size: 4 + mode: "center" + image_selectors: ["sentinel2_l2a", "sentinel1", "target/crop_type_classification/classes", "target/crop_type_classification/valid"] + train_config: + groups: ["gaza"] + tags: + split: "train" + val_config: + groups: ["gaza"] + tags: + split: "test" + test_config: + groups: ["gaza"] + tags: + split: "test" +trainer: + max_epochs: 100 + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "epoch" + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + save_top_k: 1 + save_last: true + monitor: val_loss + mode: min + - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze + init_args: + module_selector: ["model", "encoder", 0] + unfreeze_at_epoch: 20 + unfreeze_lr_factor: 10 + # - class_path: rslearn.train.prediction_writer.RslearnWriter + # init_args: + # path: /weka/dfive-default/rslearn-eai/datasets/crop/mozambique_lulc + # output_layer: prediction_v0 + # selector: ["crop_type_classification"] + # merger: + # class_path: rslearn.train.prediction_writer.RasterMerger + # init_args: + # padding: 2 +rslp_project: 2025_09_18_mozambique_lulc +rslp_experiment: mozambique_lulc_helios_base_S1_S2_ts_ws4_ps1_gaza diff --git a/data/helios/v2_mozambique_lulc/finetune_s2.yaml b/data/helios/v2_mozambique_lulc/finetune_s2.yaml new file mode 100644 index 00000000..0115655c --- /dev/null +++ b/data/helios/v2_mozambique_lulc/finetune_s2.yaml @@ -0,0 +1,241 @@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + encoder: + - class_path: rslearn.models.olmoearth_pretrain.model.OlmoEarth + init_args: + checkpoint_path: /weka/dfive-default/helios/checkpoints/henryh/base_v6.1_add_chm_cdl_worldcereal/step500000 + selector: ["encoder"] + forward_kwargs: + patch_size: 1 + decoders: + crop_type_classification: + - class_path: rslp.nandi.train.SegmentationPoolingDecoder + init_args: + in_channels: 768 + out_channels: 8 + - class_path: rslearn.train.tasks.segmentation.SegmentationHead + lr: 0.0001 + scheduler: + class_path: rslearn.train.scheduler.PlateauScheduler + init_args: + factor: 0.2 + patience: 2 + min_lr: 0 + cooldown: 10 +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: /weka/dfive-default/rslearn-eai/datasets/crop/mozambique_lulc + inputs: + sentinel2_l2a: + data_type: "raster" + layers: ["sentinel2"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + load_all_item_groups: true + load_all_layers: true + label: + data_type: "raster" + layers: ["label_raster"] + bands: ["label"] + is_target: true + dtype: INT32 + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + tasks: + crop_type_classification: + class_path: rslearn.train.tasks.segmentation.SegmentationTask + init_args: + num_classes: 8 + zero_is_invalid: true + metric_kwargs: + average: "micro" + other_metrics: + water_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 1 + water_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 1 + bareground_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 2 + bareground_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 2 + rangeland_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 3 + rangeland_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 3 + floodedvegetation_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 4 + floodedvegetation_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 4 + trees_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 5 + trees_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 5 + cropland_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 6 + cropland_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 6 + buildings_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 7 + buildings_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 7 + input_mapping: + crop_type_classification: + label: "targets" + batch_size: 32 + num_workers: 32 + default_config: + transforms: + - class_path: rslearn.models.olmoearth_pretrain.norm.OlmoEarthNormalize + init_args: + band_names: + sentinel2_l2a: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + - class_path: rslearn.train.transforms.pad.Pad + init_args: + size: 4 + mode: "center" + image_selectors: ["sentinel2_l2a", "target/crop_type_classification/classes", "target/crop_type_classification/valid"] + train_config: + groups: ["gaza"] + tags: + split: "train" + val_config: + groups: ["gaza"] + tags: + split: "test" + test_config: + groups: ["gaza"] + tags: + split: "test" +trainer: + max_epochs: 100 + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "epoch" + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + save_top_k: 1 + save_last: true + monitor: val_loss + mode: min + - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze + init_args: + module_selector: ["model", "encoder", 0] + unfreeze_at_epoch: 20 + unfreeze_lr_factor: 10 + # - class_path: rslearn.train.prediction_writer.RslearnWriter + # init_args: + # path: /weka/dfive-default/rslearn-eai/datasets/crop/mozambique_lulc + # output_layer: prediction_v0 + # selector: ["crop_type_classification"] + # merger: + # class_path: rslearn.train.prediction_writer.RasterMerger + # init_args: + # padding: 2 +rslp_project: 2025_09_18_mozambique_lulc +rslp_experiment: mozambique_lulc_helios_base_S2_ts_ws4_ps1_gaza diff --git a/data/helios/v2_mozambique_lulc/finetune_s2_20251024.yml b/data/helios/v2_mozambique_lulc/finetune_s2_20251024.yml new file mode 100644 index 00000000..ffa9463a --- /dev/null +++ b/data/helios/v2_mozambique_lulc/finetune_s2_20251024.yml @@ -0,0 +1,269 @@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + encoder: + - class_path: rslearn.models.olmoearth_pretrain.model.OlmoEarth + init_args: + checkpoint_path: /weka/dfive-default/helios/checkpoints/joer/phase2.0_base_lr0.0001_wd0.02/step667200 + selector: + - encoder + forward_kwargs: + patch_size: 1 + random_initialization: false + embedding_size: null + patch_size: null + autocast_dtype: bfloat16 + decoders: + segment: + - class_path: rslearn.models.upsample.Upsample + init_args: + scale_factor: 1 + mode: bilinear + - class_path: rslearn.models.conv.Conv + init_args: + in_channels: 768 + out_channels: 8 + kernel_size: 1 + padding: same + stride: 1 + activation: + class_path: torch.nn.Identity + init_args: {} + - class_path: rslearn.models.pick_features.PickFeatures + init_args: + indexes: + - 0 + collapse: true + - class_path: rslearn.train.tasks.segmentation.SegmentationHead + lazy_decode: false + loss_weights: null + trunk: null + lr: 0.0001 + scheduler: + class_path: rslearn.train.scheduler.PlateauScheduler + init_args: + factor: 0.2 + patience: 2 + min_lr: 0 + cooldown: 10 +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: /weka/dfive-default/rslearn-eai/datasets/crop/mozambique_lulc + inputs: + sentinel2_l2a: + data_type: "raster" + layers: ["sentinel2"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + load_all_item_groups: true + load_all_layers: true + label: + data_type: "raster" + layers: ["label_raster"] + bands: ["label"] + is_target: true + dtype: INT32 + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + tasks: + segment: + class_path: rslearn.train.tasks.segmentation.SegmentationTask + init_args: + num_classes: 8 + zero_is_invalid: true + metric_kwargs: + average: "micro" + other_metrics: + water_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 1 + water_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 1 + bareground_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 2 + bareground_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 2 + rangeland_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 3 + rangeland_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 3 + floodedvegetation_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 4 + floodedvegetation_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 4 + trees_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 5 + trees_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 5 + cropland_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 6 + cropland_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 6 + buildings_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 7 + buildings_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 7 + input_mapping: + segment: + label: "targets" + batch_size: 4 + num_workers: 32 + default_config: + transforms: + - class_path: rslearn.models.olmoearth_pretrain.norm.OlmoEarthNormalize + init_args: + band_names: + sentinel2_l2a: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + train_config: + transforms: + - class_path: rslearn.train.transforms.pad.Pad + init_args: + size: 31 + mode: "center" + image_selectors: ["sentinel2_l2a", "target/segment/classes", "target/segment/valid"] + - class_path: rslearn.train.transforms.crop.Crop + init_args: + crop_size: 16 + image_selectors: ["sentinel2_l2a", "target/segment/classes", "target/segment/valid"] + - class_path: rslearn.train.transforms.flip.Flip + init_args: + image_selectors: ["sentinel2_l2a", "target/segment/classes", "target/segment/valid"] + - class_path: rslearn.models.olmoearth_pretrain.norm.OlmoEarthNormalize + init_args: + band_names: + sentinel2_l2a: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + groups: ["gaza"] + tags: + split: "train" + val_config: + patch_size: 16 + groups: ["gaza"] + tags: + split: "test" + test_config: + patch_size: 16 + groups: ["gaza"] + tags: + split: "test" +trainer: + max_epochs: 100 + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "epoch" + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + save_top_k: 1 + save_last: true + monitor: val_loss + mode: min + - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze + init_args: + module_selector: ["model", "encoder", 0] + unfreeze_at_epoch: 20 + unfreeze_lr_factor: 10 +rslp_project: 2025_09_18_mozambique_lulc +rslp_experiment: mozambique_lulc_helios_base_S2_ts_ws16_ps1_gaza diff --git a/esrun_data/crop/mozambique/dataset.json b/esrun_data/crop/mozambique/dataset.json new file mode 100644 index 00000000..447444e0 --- /dev/null +++ b/esrun_data/crop/mozambique/dataset.json @@ -0,0 +1,184 @@ +{ + "layers": { + "es_output": { + "band_sets": [ + { + "bands": [ + "output" + ], + "dtype": "uint8" + } + ], + "type": "raster" + }, + "label": { + "type": "vector" + }, + "label_raster": { + "band_sets": [ + { + "bands": [ + "label" + ], + "dtype": "int32" + } + ], + "type": "raster" + }, + "output": { + "band_sets": [ + { + "bands": [ + "output" + ], + "dtype": "uint8" + } + ], + "type": "raster" + }, + "prediction_v0": { + "band_sets": [ + { + "bands": [ + "output" + ], + "dtype": "uint8", + "format": { + "name": "geotiff" + } + } + ], + "type": "raster" + }, + "prediction_v1": { + "band_sets": [ + { + "bands": [ + "output" + ], + "dtype": "uint8" + } + ], + "type": "raster" + }, + "sentinel1_ascending": { + "band_sets": [ + { + "bands": [ + "vv", + "vh" + ], + "dtype": "float32" + } + ], + "data_source": { + "cache_dir": "cache/planetary_computer", + "duration": "366d", + "ingest": false, + "name": "rslp.satlas.data_sources.MonthlySentinel1", + "query": { + "sar:instrument_mode": { + "eq": "IW" + }, + "sar:polarizations": { + "eq": [ + "VV", + "VH" + ] + }, + "sat:orbit_state": { + "eq": "ascending" + } + }, + "query_config": { + "max_matches": 12 + }, + "time_offset": "-180d" + }, + "type": "raster" + }, + "sentinel1_descending": { + "band_sets": [ + { + "bands": [ + "vv", + "vh" + ], + "dtype": "float32" + } + ], + "data_source": { + "cache_dir": "cache/planetary_computer", + "duration": "366d", + "ingest": false, + "name": "rslp.satlas.data_sources.MonthlySentinel1", + "query": { + "sar:instrument_mode": { + "eq": "IW" + }, + "sar:polarizations": { + "eq": [ + "VV", + "VH" + ] + }, + "sat:orbit_state": { + "eq": "descending" + } + }, + "query_config": { + "max_matches": 12 + }, + "time_offset": "-180d" + }, + "type": "raster" + }, + "sentinel2": { + "band_sets": [ + { + "bands": [ + "B02", + "B03", + "B04", + "B08" + ], + "dtype": "uint16" + }, + { + "bands": [ + "B05", + "B06", + "B07", + "B8A", + "B11", + "B12" + ], + "dtype": "uint16", + "zoom_offset": -1 + }, + { + "bands": [ + "B01", + "B09" + ], + "dtype": "uint16", + "zoom_offset": -2 + } + ], + "data_source": { + "cache_dir": "cache/planetary_computer", + "duration": "366d", + "harmonize": true, + "ingest": false, + "max_cloud_cover": 50, + "name": "rslp.satlas.data_sources.MonthlyAzureSentinel2", + "query_config": { + "max_matches": 12 + }, + "sort_by": "eo:cloud_cover", + "time_offset": "-180d" + }, + "type": "raster" + } + } +} diff --git a/esrun_data/crop/mozambique/esrun.yaml b/esrun_data/crop/mozambique/esrun.yaml new file mode 100644 index 00000000..e03a7972 --- /dev/null +++ b/esrun_data/crop/mozambique/esrun.yaml @@ -0,0 +1,55 @@ +inference_results_config: + data_type: RASTER + classification_fields: + - property_name: crop_type_classification + band_index: 1 + allowed_values: + - value: 1 + label: Water + color: [136, 33, 233] + - value: 2 + label: Bare Ground + color: [124, 67, 18] + - value: 3 + label: Rangeland + color: [51, 160, 44] + - value: 4 + label: Flooded Vegetation + color: [169, 214, 146] + - value: 5 + label: Trees + color: [240, 159, 28] + - value: 6 + label: Cropland + color: [251, 154, 153] + - value: 7 + label: Buildings + color: [31, 234, 146] + +partition_strategies: + partition_request_geometry: + class_path: esrun.runner.tools.partitioners.grid_partitioner.GridPartitioner + init_args: + grid_size: 0.1 + + prepare_window_geometries: + class_path: esrun.runner.tools.partitioners.grid_partitioner.GridPartitioner + init_args: + grid_size: 64 + output_projection: + class_path: rslearn.utils.geometry.Projection + init_args: + crs: EPSG:3857 + x_resolution: 10 + y_resolution: -10 + use_utm: true + +postprocessing_strategies: + process_dataset: + class_path: esrun.runner.tools.postprocessors.combine_geotiff.CombineGeotiff + + process_partition: + class_path: esrun.runner.tools.postprocessors.combine_geotiff.CombineGeotiff + + process_window: + class_path: esrun.runner.tools.postprocessors.noop_raster.NoopRaster diff --git a/esrun_data/crop/mozambique/gaza_geometry.json b/esrun_data/crop/mozambique/gaza_geometry.json new file mode 100644 index 00000000..e0848d0c --- /dev/null +++ b/esrun_data/crop/mozambique/gaza_geometry.json @@ -0,0 +1,191 @@ +{ + "geometry": { + "features": [ + { + "geometry": { + "coordinates": [ + [ + [ + 32.38167662388802, + -21.27408677170535 + ], + [ + 32.45813769443847, + -21.2933563863121 + ], + [ + 32.62344095182011, + -21.33351302200554 + ], + [ + 32.79617837829191, + -21.385222147670355 + ], + [ + 33.030607742789364, + -21.511545253042982 + ], + [ + 33.147822425038086, + -21.40245446502164 + ], + [ + 33.3020522701022, + -21.431170479655627 + ], + [ + 33.33906743291759, + -21.511545253042982 + ], + [ + 33.29588307629963, + -21.5918756113073 + ], + [ + 33.258867913484245, + -21.79250631724419 + ], + [ + 33.38225178953554, + -21.95853078614493 + ], + [ + 33.45628211516631, + -22.301413291161854 + ], + [ + 33.369913401930404, + -22.58081167011402 + ], + [ + 33.622850347835545, + -22.802786104769332 + ], + [ + 33.88812568134582, + -23.14357944854419 + ], + [ + 33.96333585869647, + -23.46542008102314 + ], + [ + 33.99415258048534, + -23.66314880839998 + ], + [ + 33.97703217949152, + -24.026438393135226 + ], + [ + 34.02839338247298, + -24.05458199971527 + ], + [ + 34.038665623069264, + -24.31071668189465 + ], + [ + 34.1756288310198, + -24.475992751117644 + ], + [ + 34.422162605330755, + -24.42612087853377 + ], + [ + 34.4564034073184, + -24.45105928210875 + ], + [ + 34.39476996374065, + -24.525844870463995 + ], + [ + 34.47694788851097, + -24.634827369859586 + ], + [ + 34.391345883541895, + -24.70328206211508 + ], + [ + 34.490644209306026, + -24.87736051245625 + ], + [ + 33.14155661099326, + -25.40740049895447 + ], + [ + 32.97837114264776, + -25.2568116984173 + ], + [ + 32.88817313399803, + -25.09355132588769 + ], + [ + 32.86041990056735, + -24.854547902010303 + ], + [ + 32.72859204177159, + -24.58984801919847 + ], + [ + 32.61064079969118, + -24.432022023771314 + ], + [ + 32.2429104567346, + -24.28664743325159 + ], + [ + 31.92374827228171, + -24.299295327575926 + ], + [ + 31.84048857198965, + -24.014413763057547 + ], + [ + 31.7086607131939, + -23.874907789016 + ], + [ + 31.65315424633253, + -23.627232273722857 + ], + [ + 31.528264695894446, + -23.487310629228134 + ], + [ + 31.535203004252114, + -23.213398846065775 + ], + [ + 31.250732361587584, + -22.401137684084613 + ], + [ + 32.38167662388802, + -21.27408677170535 + ] + ] + ], + "type": "Polygon" + }, + "properties": { + "oe_end_time": "2023-03-31T00:00:00Z", + "oe_start_time": "2023-03-01T00:00:00Z" + }, + "type": "Feature" + } + ], + "type": "FeatureCollection" + }, + "model_pipeline_id": "5224c4ea-7182-41a1-9497-2e59013ccdb6", + "workflow_type": "PREDICTION" +} diff --git a/esrun_data/crop/mozambique/model.yaml b/esrun_data/crop/mozambique/model.yaml new file mode 100644 index 00000000..a018c65d --- /dev/null +++ b/esrun_data/crop/mozambique/model.yaml @@ -0,0 +1,270 @@ +# FIXME ${TRAINER_DATA_PATH} +# FIXME ${WANDB_PROJECT}: wandb project for the trainer to log metrics to +# FIXME ${WANDB_NAME}: wandb name for the trainer to log metrics to +# FIXME ${WANDB_ENTITY}: wandb entity for the trainer to log metrics to +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + encoder: + - class_path: rslp.helios.model.Helios + init_args: + # this EXTRA_FILES_PATH needs to be manually copied into the + # GCS bucket. It is the helios checkpoint which was finetuned + # to obtain the checkpoint being referenced in ES Run. + checkpoint_path: ${EXTRA_FILES_PATH}/helios/step300000 + random_initialization: true + selector: ["encoder"] + forward_kwargs: + patch_size: 1 + decoders: + crop_type_classification: + - class_path: rslp.crop.kenya_nandi.train.SegmentationPoolingDecoder + init_args: + in_channels: 768 + out_channels: 8 + - class_path: rslearn.train.tasks.segmentation.SegmentationHead + lr: 0.0001 + plateau: true + plateau_factor: 0.2 + plateau_patience: 2 + plateau_min_lr: 0 + plateau_cooldown: 10 +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: ${DATASET_PATH} + inputs: + sentinel2_l2a: + data_type: "raster" + layers: ["sentinel2"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + load_all_item_groups: true + load_all_layers: true + sentinel1: + data_type: "raster" + # check which layer (ascending vs descending) was + # actually ingested in rslearn + layers: ["sentinel1_descending"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + load_all_item_groups: true + load_all_layers: true + label: + data_type: "raster" + layers: ["label_raster"] + bands: ["label"] + is_target: true + dtype: INT32 + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + tasks: + crop_type_classification: + class_path: rslearn.train.tasks.segmentation.SegmentationTask + init_args: + num_classes: 8 + zero_is_invalid: true + metric_kwargs: + average: "micro" + other_metrics: + water_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 1 + water_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 1 + bareground_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 2 + bareground_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 2 + rangeland_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 3 + rangeland_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 3 + floodedvegetation_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 4 + floodedvegetation_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 4 + trees_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 5 + trees_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 5 + cropland_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 6 + cropland_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 6 + buildings_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 7 + buildings_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 7 + input_mapping: + crop_type_classification: + label: "targets" + batch_size: 32 + num_workers: ${NUM_WORKERS} + default_config: + transforms: + - class_path: rslearn.train.transforms.pad.Pad + init_args: + size: 4 + mode: "center" + image_selectors: ["sentinel2_l2a", "target/crop_type_classification/classes", "target/crop_type_classification/valid"] + - class_path: rslp.helios.norm.HeliosNormalize + init_args: + band_names: + sentinel2_l2a: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + train_config: + groups: ["gaza"] + tags: + split: "train" + val_config: + groups: ["gaza"] + tags: + split: "test" + test_config: + groups: ["gaza"] + tags: + split: "test" + predict_config: + groups: ["gaza"] + load_all_patches: true + skip_targets: true + overlap_ratio: 0.75 + patch_size: 4 + transforms: + - class_path: rslearn.train.transforms.pad.Pad + init_args: + size: 4 + mode: "center" + image_selectors: ["sentinel2_l2a"] + - class_path: rslp.helios.norm.HeliosNormalize + init_args: + band_names: + sentinel2_l2a: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] +trainer: + max_epochs: 100 + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "epoch" + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + save_top_k: 1 + save_last: true + monitor: val_loss + mode: min + - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze + init_args: + module_selector: ["model", "encoder", 0] + unfreeze_at_epoch: 2 + - class_path: rslearn.train.prediction_writer.RslearnWriter + init_args: + path: ${DATASET_PATH} + output_layer: ${PREDICTION_OUTPUT_LAYER} + selector: ["crop_type_classification"] + merger: + class_path: rslearn.train.prediction_writer.RasterMerger + init_args: + padding: 2 diff --git a/esrun_data/crop/mozambique/model_20251024.yaml b/esrun_data/crop/mozambique/model_20251024.yaml new file mode 100644 index 00000000..dfd215f1 --- /dev/null +++ b/esrun_data/crop/mozambique/model_20251024.yaml @@ -0,0 +1,260 @@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + encoder: + - class_path: rslearn.olmoearth_pretrain.model.OlmoEarth + init_args: + checkpoint_path: ${EXTRA_FILES_PATH} + random_initialization: true + selector: ["encoder"] + forward_kwargs: + patch_size: 1 + decoders: + class: + - class_path: rslp.nandi.train.SegmentationPoolingDecoder + init_args: + in_channels: 768 + out_channels: 8 + - class_path: rslearn.train.tasks.segmentation.SegmentationHead + lr: 0.0001 + plateau: true + plateau_factor: 0.2 + plateau_patience: 2 + plateau_min_lr: 0 + plateau_cooldown: 10 +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: ${DATASET_PATH} + inputs: + sentinel2_l2a: + data_type: "raster" + layers: ["sentinel2"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + load_all_item_groups: true + load_all_layers: true + label: + data_type: "raster" + layers: ["label_raster"] + bands: ["label"] + is_target: true + dtype: INT32 + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + tasks: + class: + class_path: rslearn.train.tasks.segmentation.SegmentationTask + init_args: + num_classes: 8 + zero_is_invalid: true + metric_kwargs: + average: "micro" + other_metrics: + coffee_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 1 + coffee_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 1 + trees_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 2 + trees_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 2 + grassland_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 3 + grassland_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 3 + maize_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 4 + maize_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 4 + sugarcane_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 5 + sugarcane_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 5 + tea_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 6 + tea_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 6 + vegetables_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 7 + vegetables_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 7 + input_mapping: + class: + label: "targets" + batch_size: 32 + num_workers: ${NUM_WORKERS} + default_config: + transforms: + - class_path: rslearn.train.transforms.pad.Pad + init_args: + size: 4 + mode: "center" + image_selectors: ["sentinel2_l2a", "target/crop_type_classification/classes", "target/crop_type_classification/valid"] + - class_path: rslearn.models.olmoearth_pretrain.norm.OlmoEarthNormalize + init_args: + band_names: + sentinel2_l2a: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + train_config: + groups: ["groundtruth_polygon_split_window_32", "worldcover_window_32"] + tags: + split: "train" + val_config: + groups: ["groundtruth_polygon_split_window_32", "worldcover_window_32"] + tags: + split: "val" + test_config: + groups: ["groundtruth_polygon_split_window_32", "worldcover_window_32"] + tags: + split: "val" + predict_config: + groups: ["nandi_county"] + load_all_patches: true + skip_targets: true + overlap_ratio: 0.75 + patch_size: 4 + transforms: + - class_path: rslearn.train.transforms.pad.Pad + init_args: + size: 4 + mode: "center" + image_selectors: ["sentinel2_l2a"] + - class_path: rslearn.models.olmoearth_pretrain.norm.OlmoEarthNormalize + init_args: + band_names: + sentinel2_l2a: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] +trainer: + max_epochs: 100 + default_root_dir: ${TRAINER_DATA_PATH} + logger: + class_path: lightning.pytorch.loggers.WandbLogger + init_args: + project: ${WANDB_PROJECT} + name: ${WANDB_NAME} + entity: ${WANDB_ENTITY} + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "epoch" + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + save_top_k: 1 + save_last: true + monitor: val_loss + mode: min + - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze + init_args: + module_selector: ["model", "encoder", 0] + unfreeze_at_epoch: 2 + - class_path: rslearn.train.prediction_writer.RslearnWriter + init_args: + path: placeholder + output_layer: ${PREDICTION_OUTPUT_LAYER} + selector: ["class"] + merger: + class_path: rslearn.train.prediction_writer.RasterMerger + init_args: + padding: 2 diff --git a/esrun_data/crop/mozambique/prediction_request_geometry.json b/esrun_data/crop/mozambique/prediction_request_geometry.json new file mode 100644 index 00000000..14a605ed --- /dev/null +++ b/esrun_data/crop/mozambique/prediction_request_geometry.json @@ -0,0 +1,39 @@ +{ + "features": [ + { + "geometry": { + "coordinates": [ + [ + [ + 33.37, + -25.16 + ], + [ + 33.37, + -24.85 + ], + [ + 33.82, + -24.85 + ], + [ + 33.82, + -25.16 + ], + [ + 33.37, + -25.16 + ] + ] + ], + "type": "Polygon" + }, + "properties": { + "es_end_time": "2023-03-31T00:00:00+00:00", + "es_start_time": "2023-03-01T00:00:00+00:00" + }, + "type": "Feature" + } + ], + "type": "FeatureCollection" +} diff --git a/rslp/mozambique/README.md b/rslp/mozambique/README.md new file mode 100644 index 00000000..09dc0b21 --- /dev/null +++ b/rslp/mozambique/README.md @@ -0,0 +1,43 @@ +# Mozambique LULC and Crop Type Classification + +This project has two main tasks: + 1. Land Use/Land Cover (LULC) and cropland classification + 2. Crop type classification + +The annotations come from field surveys across three provinces in Mozambique: Gaza, Zambezia, and Manica. + +For LULC classification, the train/test splits are: +- Gaza: 2,262 / 970 +- Manica: 1,917 / 822 +- Zambezia: 1,225 / 525 + +## LULC Classification + +#### 2025-11-05 + +Updates so that it works with all the changes. Also now a forward pass does a patch not a pixel; this makes inference far faster. + +``` +python -m rslp.main olmoearth_pretrain launch_finetune --image_name gabrielt/rslpomp_20251027b --config_paths+=data/helios/v2_mozambique_lulc/finetune_s2_20251024.yml --cluster+=ai2/saturn --rslp_project 2025_09_18_mozambique_lulc --experiment_id mozambique_lulc_helios_base_S2_ts_ws4_ps1_gaza_20251105_saturn_b +``` + +#### 2025-10-23 + +Update S1 and S2 training scripts to run with all the updates. This also requires running `python -m rslp.main olmoearth_pretrain` instead of `python -m rslp.main helios`: + +``` +python -m rslp.main olmoearth_pretrain launch_finetune --image_name favyen/favyen/rslpomp20251022a --config_paths+=data/helios/v2_mozambique_lulc/finetune_s2.yaml --cluster+=ai2/neptune --rslp_project 2025_09_18_mozambique_lulc --experiment_id mozambique_lulc_helios_base_S2_ts_ws4_ps1_gaza_20251023 +``` + +Also - the geometry for Gaza province was enormous (hundreds of thousands of points). I have drawn a cruder polygon around the province for the prediction request geometry to try and keep things manageable. + +#### Original commands +``` +python /weka/dfive-default/yawenz/rslearn_projects/rslp/crop/mozambique/create_windows_for_lulc.py --gpkg_dir /weka/dfive-default/yawenz/datasets/mozambique/train_test_samples --ds_path /weka/dfive-default/rslearn-eai/datasets/crop/mozambique_lulc --window_size 32 + +export DATASET_PATH=/weka/dfive-default/rslearn-eai/datasets/crop/mozambique_lulc +rslearn dataset prepare --root $DATASET_PATH --workers 64 --no-use-initial-job --retry-max-attempts 8 --retry-backoff-seconds 60 +python -m rslp.main common launch_data_materialization_jobs --image favyen/rslp_image --ds_path $DATASET_PATH --clusters+=ai2/neptune-cirrascale --num_jobs 5 + +python -m rslp.main helios launch_finetune --image_name favyen/rslphelios10 --config_paths+=data/helios/v2_mozambique_lulc/finetune_s1_s2.yaml --cluster+=ai2/neptune --rslp_project 2025_09_18_mozambique_lulc --experiment_id mozambique_lulc_helios_base_S1_S2_ts_ws4_ps1_gaza +``` diff --git a/rslp/mozambique/create_label_raster.py b/rslp/mozambique/create_label_raster.py new file mode 100644 index 00000000..89eaf79c --- /dev/null +++ b/rslp/mozambique/create_label_raster.py @@ -0,0 +1,89 @@ +"""Create label_raster from label. + +If you run this, you will need to update the config.json for the dataset +to include the following entry: + +"label_raster": { + "band_sets": [ + { + "bands": [ + "label" + ], + "dtype": "int32" + } + ], + "type": "raster" + }, +""" + +import argparse +import multiprocessing + +import numpy as np +import tqdm +from rslearn.dataset.dataset import Dataset +from rslearn.dataset.window import Window +from rslearn.utils.raster_format import GeotiffRasterFormat +from rslearn.utils.vector_format import GeojsonVectorFormat +from upath import UPath + +CLASS_NAMES = [ + "invalid", + "Water", + "Bare Ground", + "Rangeland", + "Flooded Vegetation", + "Trees", + "Cropland", + "Buildings", +] +PROPERTY_NAME = "category" +BAND_NAME = "label" + + +def create_label_raster(window: Window) -> None: + """Create label raster for the given window.""" + label_dir = window.get_layer_dir("label") + features = GeojsonVectorFormat().decode_vector( + label_dir, window.projection, window.bounds + ) + class_name = features[0].properties[PROPERTY_NAME] + class_id = CLASS_NAMES.index(class_name) + + # Draw the class_id in the middle 1x1 of the raster. + raster = np.zeros( + (1, window.bounds[3] - window.bounds[1], window.bounds[2] - window.bounds[0]), + dtype=np.uint8, + ) + raster[:, raster.shape[1] // 2, raster.shape[2] // 2] = class_id + raster_dir = window.get_raster_dir("label_raster", [BAND_NAME]) + GeotiffRasterFormat().encode_raster( + raster_dir, window.projection, window.bounds, raster + ) + window.mark_layer_completed("label_raster") + + +if __name__ == "__main__": + multiprocessing.set_start_method("forkserver") + parser = argparse.ArgumentParser() + parser.add_argument( + "--ds_path", + type=str, + required=True, + help="Path to the dataset", + ) + parser.add_argument( + "--workers", + type=int, + default=64, + help="Number of worker processes to use", + ) + args = parser.parse_args() + + dataset = Dataset(UPath(args.ds_path)) + windows = dataset.load_windows(workers=args.workers, show_progress=True) + p = multiprocessing.Pool(args.workers) + outputs = p.imap_unordered(create_label_raster, windows) + for _ in tqdm.tqdm(outputs, total=len(windows)): + pass + p.close() diff --git a/rslp/mozambique/create_windows_for_lulc.py b/rslp/mozambique/create_windows_for_lulc.py new file mode 100644 index 00000000..2364a427 --- /dev/null +++ b/rslp/mozambique/create_windows_for_lulc.py @@ -0,0 +1,248 @@ +"""Create windows for crop type mapping from GPKG files (fixed splits).""" + +import argparse +import multiprocessing +from collections.abc import Iterable +from datetime import datetime, timezone +from pathlib import Path + +import geopandas as gpd +import shapely +import tqdm +from rslearn.const import WGS84_PROJECTION +from rslearn.dataset import Window +from rslearn.utils import Projection, STGeometry, get_utm_ups_crs +from rslearn.utils.feature import Feature +from rslearn.utils.mp import star_imap_unordered +from rslearn.utils.vector_format import GeojsonVectorFormat +from upath import UPath + +from rslp.utils.windows import calculate_bounds + +WINDOW_RESOLUTION = 10 +LABEL_LAYER = "label" + +CLASS_MAP = { + 0: "Water", + 1: "Bare Ground", + 2: "Rangeland", + 3: "Flooded Vegetation", + 4: "Trees", + 5: "Cropland", + 6: "Buildings", +} + +# Per-province temporal coverage (UTC) +PROVINCE_TIME = { + "gaza": ( + datetime(2024, 10, 23, tzinfo=timezone.utc), + datetime(2025, 5, 7, tzinfo=timezone.utc), + ), + "manica": ( + datetime(2024, 11, 23, tzinfo=timezone.utc), + datetime(2025, 6, 7, tzinfo=timezone.utc), + ), + "zambezia": ( + datetime(2024, 11, 23, tzinfo=timezone.utc), + datetime(2025, 6, 7, tzinfo=timezone.utc), + ), +} + + +def process_gpkg(gpkg_path: UPath) -> gpd.GeoDataFrame: + """Load a GPKG and ensure lon/lat in WGS84; expect 'fid' and 'class' columns.""" + gdf = gpd.read_file(str(gpkg_path)) + + # Normalize CRS to WGS84 + if gdf.crs is None: + gdf = gdf.set_crs("EPSG:4326", allow_override=True) + else: + gdf = gdf.to_crs("EPSG:4326") + + required_cols = {"class", "geometry"} + missing = [c for c in required_cols if c not in gdf.columns] + if missing: + raise ValueError(f"{gpkg_path}: missing required column(s): {missing}") + + return gdf + + +def iter_points(gdf: gpd.GeoDataFrame) -> Iterable[tuple[int, float, float, int]]: + """Yield (fid, latitude, longitude, category) per feature using centroid for polygons.""" + for fid, row in gdf.iterrows(): + geom = row.geometry + if geom is None or geom.is_empty: + continue + if isinstance(geom, shapely.Point): + pt = geom + else: + pt = geom.centroid + lon, lat = float(pt.x), float(pt.y) + category = int(row["class"]) + yield fid, lat, lon, category + + +def create_window( + rec: tuple[int, float, float, int], + ds_path: UPath, + group_name: str, + split: str, + window_size: int, + start_time: datetime, + end_time: datetime, +) -> None: + """Create a single window and write label layer.""" + fid, latitude, longitude, category_id = rec + category_label = CLASS_MAP.get(category_id, f"Unknown_{category_id}") + + # Geometry/projection + src_point = shapely.Point(longitude, latitude) + src_geometry = STGeometry(WGS84_PROJECTION, src_point, None) + dst_crs = get_utm_ups_crs(longitude, latitude) + dst_projection = Projection(dst_crs, WINDOW_RESOLUTION, -WINDOW_RESOLUTION) + dst_geometry = src_geometry.to_projection(dst_projection) + bounds = calculate_bounds(dst_geometry, window_size) + + # Group = province name; split is taken from file name (train/test) + group = group_name + window_name = f"{fid}_{latitude:.6f}_{longitude:.6f}" + + window = Window( + path=Window.get_window_root(ds_path, group, window_name), + group=group, + name=window_name, + projection=dst_projection, + bounds=bounds, + time_range=(start_time, end_time), + options={ + "split": split, # 'train' or 'test' as provided + "category_id": category_id, + "category": category_label, + "fid": fid, + "source": "gpkg", + }, + ) + window.save() + + # Label layer (same as before, using window geometry) + feature = Feature( + window.get_geometry(), + { + "category_id": category_id, + "category": category_label, + "fid": fid, + "split": split, + }, + ) + layer_dir = window.get_layer_dir(LABEL_LAYER) + GeojsonVectorFormat().encode_vector(layer_dir, [feature]) + window.mark_layer_completed(LABEL_LAYER) + + +def create_windows_from_gpkg( + gpkg_path: UPath, + ds_path: UPath, + group_name: str, + split: str, + window_size: int, + max_workers: int, + start_time: datetime, + end_time: datetime, +) -> None: + """Create windows from a single GPKG file.""" + gdf = process_gpkg(gpkg_path) + records = list(iter_points(gdf)) + + jobs = [ + dict( + rec=rec, + ds_path=ds_path, + group_name=group_name, + split=split, + window_size=window_size, + start_time=start_time, + end_time=end_time, + ) + for rec in records + ] + + print( + f"[{group_name}:{split}] file={gpkg_path.name} features={len(jobs)} " + f"time={start_time.date()}→{end_time.date()}" + ) + + if max_workers <= 1: + for kw in tqdm.tqdm(jobs): + create_window(**kw) + else: + p = multiprocessing.Pool(max_workers) + outputs = star_imap_unordered(p, create_window, jobs) + for _ in tqdm.tqdm(outputs, total=len(jobs)): + pass + p.close() + + +if __name__ == "__main__": + multiprocessing.set_start_method("forkserver", force=True) + + parser = argparse.ArgumentParser(description="Create windows from GPKG files") + parser.add_argument( + "--gpkg_dir", + type=str, + required=True, + help="Directory containing gaza_[train|test].gpkg, manica_[train|test].gpkg, zambezia_[train|test].gpkg", + ) + parser.add_argument( + "--ds_path", + type=str, + required=True, + help="Path to the dataset root", + ) + parser.add_argument( + "--window_size", + type=int, + default=1, + help="Window size (pixels per side in projected grid)", + ) + parser.add_argument( + "--max_workers", + type=int, + default=32, + help="Worker processes (set 1 for single-process)", + ) + args = parser.parse_args() + + gpkg_dir = Path(args.gpkg_dir) + ds_path = UPath(args.ds_path) + + expected = [ + ("gaza", "train", gpkg_dir / "gaza_train.gpkg"), + ("gaza", "test", gpkg_dir / "gaza_test.gpkg"), + ("manica", "train", gpkg_dir / "manica_train.gpkg"), + ("manica", "test", gpkg_dir / "manica_test.gpkg"), + ("zambezia", "train", gpkg_dir / "zambezia_train.gpkg"), + ("zambezia", "test", gpkg_dir / "zambezia_test.gpkg"), + ] + + # Basic checks + for province, _, path in expected: + if province not in PROVINCE_TIME: + raise ValueError(f"Unknown province '{province}'") + if not path.exists(): + raise FileNotFoundError(f"Missing expected file: {path}") + + # Run per file + for province, split, path in expected: + start_time, end_time = PROVINCE_TIME[province] + create_windows_from_gpkg( + gpkg_path=UPath(path), + ds_path=ds_path, + group_name=province, # group == province + split=split, # honor provided split + window_size=args.window_size, + max_workers=args.max_workers, + start_time=start_time, + end_time=end_time, + ) + + print("Done.")