Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
fe35e7e
Refactor case_studies/weak_lensing
timwhite0 Dec 27, 2025
fec9582
Fix refactor bug and move wl/cached_dataset
timwhite0 Dec 27, 2025
775f9a4
Update dc2 config
timwhite0 Dec 27, 2025
fd2c5a5
Allow count splits instead of percent splits
timwhite0 Dec 28, 2025
9fca28c
Remove deprecated files for generating encoder input
timwhite0 Dec 29, 2025
22b5a5b
Script for running AnaCal
timwhite0 Dec 29, 2025
80c31f0
Update anacal script
timwhite0 Dec 30, 2025
a0e758e
Ignore anacal psf/mask/etc during NPE training
timwhite0 Jan 8, 2026
6d1fbac
Update config
timwhite0 Jan 8, 2026
5e4d2b8
Move some descwl notebooks to deprecated folder
timwhite0 Jan 8, 2026
d851701
Deprecated scripts
timwhite0 Jan 8, 2026
ec2d8a3
Rename and reorganize descwl
timwhite0 Jan 8, 2026
8fde2e0
Rename and reorganize DC2
timwhite0 Jan 8, 2026
f552bc8
Rename descwl images notebook
timwhite0 Jan 8, 2026
c48bc54
Rename config
timwhite0 Jan 8, 2026
25f9ec1
Update config paths
timwhite0 Jan 8, 2026
cb63519
Move descwl ckpt
timwhite0 Jan 8, 2026
7f53ad5
Update descwl NPE credible intervals script
timwhite0 Jan 9, 2026
e852613
Update descwl NPE credible intervals script
timwhite0 Jan 9, 2026
c51e897
Speed up AnaCal and run on same test set as NPE
timwhite0 Jan 9, 2026
abb7470
Get rid of as_completed in AnaCal multiprocessing
timwhite0 Jan 9, 2026
36da8fc
Update descwl credible intervals notebook
timwhite0 Jan 9, 2026
7bde37b
Fix and run all DC2 notebooks
timwhite0 Jan 9, 2026
820b38f
Update descwl image notebook
timwhite0 Jan 9, 2026
97758ed
Update descwl scatterplots
timwhite0 Jan 9, 2026
7726f7e
Update README
timwhite0 Jan 9, 2026
18d50bb
ruff + update gitignore
timwhite0 Jan 9, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,8 @@ data/dc2/merged_catalog
case_studies/sdss_galaxies/models/simulated_blended_galaxies.pt
case_studies/*/data/
venv
nohup.out
*.out
multirun
DC2_*.out
input_images/
WeakLensingResults*
ResultsDC2*
Results*
case_studies/weak_lensing/**/*.pt
43 changes: 34 additions & 9 deletions case_studies/weak_lensing/README.md
Original file line number Diff line number Diff line change
@@ -1,19 +1,44 @@
### Neural posterior estimation for tomographic field-level weak lensing inference
#### Tim White, Shreyas Chandrashekaran, Dingrui Tao, Camille Avestruz, and Jeffrey Regier
#### with assistance from Steve Fan and Tahseen Younus
## Neural posterior estimation for field-level weak lensing inference
#### Tim White, Shreyas Chandrashekaran, Dingrui Tao, Camille Avestruz, and Jeffrey Regier, with assistance from Steve Fan and Tahseen Younus

In this case study, we use neural posterior estimation to infer tomographic shear and convergence maps from LSST-like images.
In this case study, we use neural posterior estimation (NPE) to infer weak lensing shear and convergence from LSST-like images. We use NPE to (1) infer tomographic mass maps for the [DC2 Simulated Sky Survey](https://data.lsstdesc.org/doc/dc2_sim_sky_survey) and (2) infer constant shear from images generated with the [`descwl-shear-sims` package](https://github.com/timwhite0/descwl-shear-sims).

To train the encoder on DC2 images, run
## DC2

### Generate catalog

```
nohup python -u case_studies/weak_lensing/dc2/generate_catalog.py &> generate_catalog.out &
```

### Generate mass maps and train MassMapEncoder

```
nohup bliss -cp <path>/bliss/case_studies/weak_lensing/dc2 -cn config_train_npe.yaml mode=train &> train_dc2.out &
```

### Notebooks

- **In `dc2/results`**: `credibleintervals.ipynb`, `posteriormeanmaps.ipynb`
- **In `dc2/exploratory`**: `dc2imageandmaps.ipynb`, `ellipticity.ipynb`, `galaxyproperties.ipynb`, `twopoint.ipynb`

## descwl-shear-sims

### Train ScalarShearEncoder

```
nohup bliss -cp <path>/bliss/case_studies/weak_lensing/dc2 -cn config_dc2.yaml mode=train &> train_on_dc2.out &
nohup bliss -cp <path>/bliss/case_studies/weak_lensing/descwl -cn config_train_npe.yaml mode=train &> train_descwl.out &
```

To train the encoder on descwl-shear-sims images, run
### Run [AnaCal](https://github.com/mr-superonion/AnaCal)

Configure settings in `descwl/config_run_anacal.yaml`.

```
nohup bliss -cp <path>/bliss/case_studies/weak_lensing/descwl -cn config_descwl.yaml mode=train &> train_on_descwl.out &
nohup python -u case_studies/weak_lensing/descwl/run_anacal.py &> run_anacal.out &
```

See `dc2/notebooks` and `descwl/notebooks` for some exploratory plots and our most recent results.
### Notebooks

- **In `descwl/results`**: `compute_npe_credibleintervals.py`, `credibleintervals.ipynb`, `scatterplots.ipynb`
- **In `descwl/exploratory`**: `images.ipynb`
267 changes: 267 additions & 0 deletions case_studies/weak_lensing/cached_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,267 @@
"""Cached dataset infrastructure for weak lensing case study.

This is a standalone copy of the data loading infrastructure from bliss.cached_dataset,
modified to remove GlobalEnv dependency.
"""

import functools
import logging
import math
import operator
import os
import pathlib
import random
import re
import warnings
from typing import List

import pytorch_lightning as pl
import torch
from torch import distributed as dist
from torch.utils.data import DataLoader, Dataset, DistributedSampler, Sampler
from torchvision import transforms

# Suppress pytorch_lightning warnings
warnings.filterwarnings(
"ignore", ".*does not have many workers which may be a bottleneck.*", UserWarning
)
warnings.filterwarnings("ignore", ".*Total length of .* across ranks is zero.*", UserWarning)


class ChunkingSampler(Sampler):
"""Sampler that respects chunked data ordering."""

def __init__(self, dataset: Dataset) -> None:
super().__init__()
assert isinstance(dataset, ChunkingDataset), "dataset should be ChunkingDataset"
self.dataset = dataset

def __len__(self):
return len(self.dataset)

def __iter__(self):
return iter(self.dataset.get_chunked_indices())


class DistributedChunkingSampler(DistributedSampler):
"""Distributed sampler that respects chunked data ordering."""

def __init__(
self,
dataset: Dataset,
num_replicas: int | None = None,
rank: int | None = None,
shuffle: bool = False,
seed: int = 0,
drop_last: bool = False,
) -> None:
assert isinstance(dataset, ChunkingDataset), "dataset should be ChunkingDataset"
assert not shuffle, "you should not use shuffle"
super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last)

def __iter__(self):
pre_indices = super().__iter__()
chunked_indices = self.dataset.get_chunked_indices()
return iter([chunked_indices[i] for i in pre_indices])


class ChunkingDataset(Dataset):
"""Dataset that loads data from chunked .pt files."""

def __init__(self, file_paths, shuffle=False, transform=None, seed=0) -> None:
super().__init__()
self.file_paths = file_paths
self.shuffle = shuffle
self.transform = transform
self.seed = seed
self.current_epoch = 0 # Updated by datamodule

self.accumulated_file_sizes = torch.zeros(len(self.file_paths), dtype=torch.int64)
for i, file_path in enumerate(self.file_paths):
file_size_match = re.search(r"size_(\d+)", file_path)
if file_size_match:
cached_data_len = int(file_size_match.group(1))
else:
if i == 0:
logger = logging.getLogger("ChunkingDataset")
warning_msg = (
"WARNING: add postfix '_size_<chunk size>' to file name; "
"otherwise it'll be very slow \n"
)
logger.warning(warning_msg)
with open(file_path, "rb") as f:
cached_data_len = len(torch.load(f, weights_only=False))

if i == 0:
self.accumulated_file_sizes[i] = cached_data_len
else:
self.accumulated_file_sizes[i] = (
self.accumulated_file_sizes[i - 1] + cached_data_len
)

self.buffered_file_index = None
self.buffered_data = None

def __len__(self):
return self.accumulated_file_sizes[-1].item()

def __getitem__(self, index):
converted_index = (self.accumulated_file_sizes <= index).sum().item()
converted_sub_index = (index - self.accumulated_file_sizes[converted_index]).item()
if self.buffered_file_index != converted_index:
self.buffered_file_index = converted_index
with open(self.file_paths[converted_index], "rb") as f:
self.buffered_data = torch.load(f, weights_only=False)
output_data = self.buffered_data[converted_sub_index]
return self.transform(output_data)

def get_chunked_indices(self):
"""Get indices respecting chunk boundaries, with optional shuffling."""
accumulated_file_sizes_list = self.accumulated_file_sizes.tolist()

output_list = []
if self.shuffle:
# Use seed + epoch for reproducible shuffling
epoch_seed = self.seed + self.current_epoch
logger = logging.getLogger("ChunkingDataset")
logger.info(
"INFO: seed is %d; current epoch is %d; epoch_seed is set to %d",
self.seed,
self.current_epoch,
epoch_seed,
)
right_shift_list = [0, *accumulated_file_sizes_list[:-1]]
for start, end in zip(right_shift_list, accumulated_file_sizes_list, strict=True):
rng = random.Random(epoch_seed)
output_list.append(rng.sample(range(start, end), end - start))
random.Random(epoch_seed).shuffle(output_list)
return functools.reduce(operator.iadd, output_list, [])

return list(range(0, len(self)))


class CachedSimulatedDataModule(pl.LightningDataModule):
"""DataModule for loading cached simulation data from .pt files."""

def __init__(
self,
splits: str,
batch_size: int,
num_workers: int,
cached_data_path: str,
train_transforms: List,
nontrain_transforms: List,
subset_fraction: float = None,
shuffle_file_order: bool = True,
seed: int = 0,
splits_type: str = "percent",
):
super().__init__()

self.splits = splits
self.batch_size = batch_size
self.num_workers = num_workers
self.cached_data_path = pathlib.Path(cached_data_path)
self.train_transforms = train_transforms
self.nontrain_transforms = nontrain_transforms
self.subset_fraction = subset_fraction
self.shuffle_file_order = shuffle_file_order
self.seed = seed
self.splits_type = splits_type

self.file_paths = None
self.slices = None
self.train_dataset = None
self.val_dataset = None
self.test_dataset = None
self.predict_dataset = None

def setup(self, stage: str) -> None:
if self.file_paths is None or self.slices is None:
self._load_file_paths_and_slices()

if stage == "fit":
self.train_dataset = self._get_dataset(
self.file_paths[self.slices[0]], self.train_transforms, shuffle=True
)
self.val_dataset = self._get_dataset(
self.file_paths[self.slices[1]], self.nontrain_transforms
)
return None

if stage == "validate":
if self.val_dataset is None:
self.val_dataset = self._get_dataset(
self.file_paths[self.slices[1]], self.nontrain_transforms
)
return None

if stage == "test":
self.test_dataset = self._get_dataset(
self.file_paths[self.slices[2]], self.nontrain_transforms
)
return None

if stage == "predict":
self.predict_dataset = self._get_dataset(self.file_paths, self.nontrain_transforms)
return None

raise RuntimeError(f"setup skips stage {stage}")

def _load_file_paths_and_slices(self):
file_names = [
f for f in sorted(os.listdir(str(self.cached_data_path))) if f.endswith(".pt")
]
if self.shuffle_file_order:
random.shuffle(file_names)
if self.subset_fraction:
file_names = file_names[: math.ceil(len(file_names) * self.subset_fraction)]
self.file_paths = [os.path.join(str(self.cached_data_path), f) for f in file_names]

self.slices = self.parse_slices(self.splits, len(self.file_paths), self.splits_type)

def _percent_to_idx(self, x, length):
"""Converts string in percent to an integer index."""
return int(float(x.strip()) / 100 * length) if x.strip() else None

def _count_to_idx(self, x):
"""Converts string count to an integer index."""
return int(x.strip()) if x.strip() else None

def parse_slices(self, splits: str, length: int, splits_type: str = "percent"):
slices = [slice(0, 0) for _ in range(3)]
for i, data_split in enumerate(splits.split("/")):
if splits_type == "percent":
slices[i] = slice(
*(self._percent_to_idx(val, length) for val in data_split.split(":"))
)
else: # count
slices[i] = slice(*(self._count_to_idx(val) for val in data_split.split(":")))
return slices

def _get_dataset(self, sub_file_paths, defined_transforms, shuffle: bool = False):
assert sub_file_paths, "No cached data found"
transform = transforms.Compose(defined_transforms)
return ChunkingDataset(sub_file_paths, shuffle=shuffle, transform=transform, seed=self.seed)

def _get_dataloader(self, my_dataset):
distributed_is_used = dist.is_available() and dist.is_initialized()
sampler_type = DistributedChunkingSampler if distributed_is_used else ChunkingSampler
return DataLoader(
my_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
sampler=sampler_type(my_dataset),
)

def train_dataloader(self):
return self._get_dataloader(self.train_dataset)

def val_dataloader(self):
return self._get_dataloader(self.val_dataset)

def test_dataloader(self):
return self._get_dataloader(self.test_dataset)

def predict_dataloader(self):
return self._get_dataloader(self.predict_dataset)
Loading