Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
48cadb0
Update figure: DC2 ground truth maps
timwhite0 Oct 21, 2025
09ce5e9
Update figure: descwl NPE credible intervals
timwhite0 Oct 21, 2025
0d99a24
Update figure: NPE credible intervals and coverage probs
timwhite0 Oct 21, 2025
cba8a6a
Switch to serif font: DC2 ground truth maps
timwhite0 Oct 21, 2025
b78567f
Update figure: DC2 ground truth maps
timwhite0 Oct 21, 2025
cfc85a8
Update figure: DC2 posterior mean maps
timwhite0 Oct 21, 2025
ff5bff3
Switch to serif font: Descwl credible intervals
timwhite0 Oct 21, 2025
5c22302
Update figure and switch to serif: cosmoDC2 galaxies
timwhite0 Oct 21, 2025
c59baa7
Update figure: cosmoDC2 galaxies
timwhite0 Oct 21, 2025
99e2e98
Update gitignore
timwhite0 Oct 21, 2025
619185e
Posterior stdev maps for DC2 redbin3
timwhite0 Nov 24, 2025
3906fe4
Posterior mean metrics for DC2 redbin3
timwhite0 Nov 24, 2025
f05eea4
np array -> tensor instead of (list of np arrays) -> tensor
timwhite0 Dec 5, 2025
09ca5b8
Better train/val/test partition via file names
timwhite0 Dec 5, 2025
1843f66
Remove alt DC2 catalog generation files which kept all galaxies in ob…
timwhite0 Dec 5, 2025
750fe60
Generate DC2 catalog without flux threshold
timwhite0 Dec 6, 2025
9af6a58
Update configs after catalog changes
timwhite0 Dec 6, 2025
29bf66b
Manually specify res_midpoint when building network layers
timwhite0 Dec 6, 2025
7079fa5
Update plot axis scales
timwhite0 Dec 6, 2025
4877d26
Update path in great lakes config
timwhite0 Dec 6, 2025
38ffed5
Update train/val/test partition in dc2.py
timwhite0 Dec 7, 2025
3037f95
Remove old redbin1 notebooks
timwhite0 Dec 8, 2025
202d1cb
Update figure: DC2 image and maps
timwhite0 Dec 10, 2025
6e4befd
Update train/val/test partition in dc2.py again
timwhite0 Dec 10, 2025
1ed8754
Remove old configs
timwhite0 Dec 10, 2025
53c2934
Update current configs
timwhite0 Dec 10, 2025
9e48d91
Rename configs
timwhite0 Dec 10, 2025
a7ec51b
Update gitignore
timwhite0 Dec 10, 2025
92f46fe
Update configs
timwhite0 Dec 10, 2025
494766b
Update figures: DC2 credible intervals and maps
timwhite0 Dec 10, 2025
bf88ca8
Update figure: DC2 galaxy properties
timwhite0 Dec 10, 2025
912804a
Update DC2 image and maps
timwhite0 Dec 10, 2025
b175adf
Update learning rate
timwhite0 Dec 12, 2025
19be518
New ckpt
timwhite0 Dec 13, 2025
214ca39
Update posterior mean map colorbar
timwhite0 Dec 13, 2025
5746961
Update plot colors
timwhite0 Dec 14, 2025
24ad10b
Merge master into tw/weak_lensing
timwhite0 Dec 14, 2025
bf9ca20
Update READMEs
timwhite0 Dec 15, 2025
236bbf7
Make res_midpoint an argument
timwhite0 Dec 15, 2025
34bdc32
Update gitignore
timwhite0 Dec 15, 2025
8cc1564
Update README
timwhite0 Dec 15, 2025
13cfff4
Make weak lensing configs and encoder independent
timwhite0 Dec 15, 2025
e3a8e8f
Rename some encoder arguments
timwhite0 Dec 15, 2025
7c256f9
ruff formatting fixes
timwhite0 Dec 15, 2025
027089a
more ruff formatting fixes
timwhite0 Dec 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,5 @@ multirun
DC2_*.out
input_images/
WeakLensingResults*
ResultsDC2*
case_studies/weak_lensing/**/*.pt
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ Configuration files are located in `bliss/conf/` and can be composed and overrid

The `case_studies/` directory contains research applications of BLISS:

- **weak_lensing/** - Shear (γ) and convergence (κ) estimation, validated on DC2 and DECaLS simulations
- **weak_lensing/** - Tomographic shear (γ) and convergence (κ) mapping
- **redshift/** - Photo-z estimation (BLISS-PZ) with multiple variational families
- **galaxy_clustering/** - Galaxy cluster detection and membership prediction, validated on DES DR2
- **dc2_cataloging/** - Full cataloging pipeline for DC2 simulation
Expand Down
144 changes: 144 additions & 0 deletions ablation.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
Architectures considered for DC2 with four redshift bins, tile side length of 256 pixels

512 true true
input torch.Size([1, 6, 5, 2048, 2048])
preprocess torch.Size([1, 64, 2048, 2048])
0 torch.Size([1, 64, 1024, 1024])
1 torch.Size([1, 64, 1024, 1024])
2 torch.Size([1, 128, 512, 512])
3 torch.Size([1, 128, 512, 512])
4 torch.Size([1, 256, 256, 256])
5 torch.Size([1, 256, 256, 256])
6 torch.Size([1, 512, 128, 128])
7 torch.Size([1, 512, 128, 128])
8 torch.Size([1, 512, 64, 64])
9 torch.Size([1, 512, 32, 32])
10 torch.Size([1, 256, 16, 16])
11 torch.Size([1, 128, 8, 8])
final torch.Size([1, 8, 8, 24])



512 true false
input torch.Size([1, 6, 5, 2048, 2048])
preprocess torch.Size([1, 64, 2048, 2048])
0 torch.Size([1, 64, 1024, 1024])
1 torch.Size([1, 128, 512, 512])
2 torch.Size([1, 256, 256, 256])
3 torch.Size([1, 512, 128, 128])
4 torch.Size([1, 512, 64, 64])
5 torch.Size([1, 512, 32, 32])
6 torch.Size([1, 256, 16, 16])
7 torch.Size([1, 128, 8, 8])
final torch.Size([1, 8, 8, 24])



512 false true
input torch.Size([1, 6, 5, 2048, 2048])
preprocess torch.Size([1, 64, 2048, 2048])
0 torch.Size([1, 64, 2048, 2048])
1 torch.Size([1, 64, 2048, 2048])
2 torch.Size([1, 128, 1024, 1024])
3 torch.Size([1, 128, 1024, 1024])
4 torch.Size([1, 256, 512, 512])
5 torch.Size([1, 256, 512, 512])
6 torch.Size([1, 512, 256, 256])
7 torch.Size([1, 512, 256, 256])
8 torch.Size([1, 512, 128, 128])
9 torch.Size([1, 512, 128, 128])
10 torch.Size([1, 512, 64, 64])
11 torch.Size([1, 512, 32, 32])
12 torch.Size([1, 256, 16, 16])
13 torch.Size([1, 128, 8, 8])
final torch.Size([1, 8, 8, 24])



512 false false
input torch.Size([1, 6, 5, 2048, 2048])
preprocess torch.Size([1, 64, 2048, 2048])
0 torch.Size([1, 64, 2048, 2048])
1 torch.Size([1, 128, 1024, 1024])
2 torch.Size([1, 256, 512, 512])
3 torch.Size([1, 512, 256, 256])
4 torch.Size([1, 512, 128, 128])
5 torch.Size([1, 512, 64, 64])
6 torch.Size([1, 512, 32, 32])
7 torch.Size([1, 256, 16, 16])
8 torch.Size([1, 128, 8, 8])
final torch.Size([1, 8, 8, 24])



1024 true true
input torch.Size([1, 6, 5, 2048, 2048])
preprocess torch.Size([1, 64, 2048, 2048])
0 torch.Size([1, 64, 1024, 1024])
1 torch.Size([1, 64, 1024, 1024])
2 torch.Size([1, 128, 512, 512])
3 torch.Size([1, 128, 512, 512])
4 torch.Size([1, 256, 256, 256])
5 torch.Size([1, 256, 256, 256])
6 torch.Size([1, 512, 128, 128])
7 torch.Size([1, 512, 128, 128])
8 torch.Size([1, 1024, 128, 128])
9 torch.Size([1, 1024, 128, 128])
10 torch.Size([1, 1024, 64, 64])
11 torch.Size([1, 512, 32, 32])
12 torch.Size([1, 256, 16, 16])
13 torch.Size([1, 128, 8, 8])
final torch.Size([1, 8, 8, 24])



1024 true false
input torch.Size([1, 6, 5, 2048, 2048])
preprocess torch.Size([1, 64, 2048, 2048])
0 torch.Size([1, 64, 1024, 1024])
1 torch.Size([1, 128, 512, 512])
2 torch.Size([1, 256, 256, 256])
3 torch.Size([1, 512, 128, 128])
4 torch.Size([1, 1024, 128, 128])
5 torch.Size([1, 1024, 64, 64])
6 torch.Size([1, 512, 32, 32])
7 torch.Size([1, 256, 16, 16])
8 torch.Size([1, 128, 8, 8])
final torch.Size([1, 8, 8, 24])



1024 false true
input torch.Size([1, 6, 5, 2048, 2048])
preprocess torch.Size([1, 64, 2048, 2048])
0 torch.Size([1, 64, 2048, 2048])
1 torch.Size([1, 64, 2048, 2048])
2 torch.Size([1, 128, 1024, 1024])
3 torch.Size([1, 128, 1024, 1024])
4 torch.Size([1, 256, 512, 512])
5 torch.Size([1, 256, 512, 512])
6 torch.Size([1, 512, 256, 256])
7 torch.Size([1, 512, 256, 256])
8 torch.Size([1, 1024, 128, 128])
9 torch.Size([1, 1024, 128, 128])
10 torch.Size([1, 1024, 64, 64])
11 torch.Size([1, 512, 32, 32])
12 torch.Size([1, 256, 16, 16])
13 torch.Size([1, 128, 8, 8])
final torch.Size([1, 8, 8, 24])



1024 false false
input torch.Size([1, 6, 5, 2048, 2048])
preprocess torch.Size([1, 64, 2048, 2048])
0 torch.Size([1, 64, 2048, 2048])
1 torch.Size([1, 128, 1024, 1024])
2 torch.Size([1, 256, 512, 512])
3 torch.Size([1, 512, 256, 256])
4 torch.Size([1, 1024, 128, 128])
5 torch.Size([1, 1024, 64, 64])
6 torch.Size([1, 512, 32, 32])
7 torch.Size([1, 256, 16, 16])
8 torch.Size([1, 128, 8, 8])
final torch.Size([1, 8, 8, 24])
2 changes: 1 addition & 1 deletion bliss/surveys/dc2.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ def get_bands_flux_and_psf(cls, bands, catalog, median=True):
else:
psf_params_cur_band.append(catalog[f"{i}{b}"].values.astype(np.float32))
psf_params_list.append(
torch.tensor(psf_params_cur_band)
torch.tensor(np.array(psf_params_cur_band))
) # bands x 4 (params per band) x n_obj

return torch.stack(flux_list).t(), torch.stack(psf_params_list).unsqueeze(0)
24 changes: 9 additions & 15 deletions case_studies/weak_lensing/README.md
Original file line number Diff line number Diff line change
@@ -1,25 +1,19 @@
### Neural posterior estimation of weak lensing shear and convergence
#### Tim White, Shreyas Chandrashekaran, Camille Avestruz, and Jeffrey Regier
#### with assistance from Dingrui Tao, Steve Fan, and Tahseen Younus
### Neural posterior estimation for tomographic field-level weak lensing inference
#### Tim White, Shreyas Chandrashekaran, Dingrui Tao, Camille Avestruz, and Jeffrey Regier
#### with assistance from Steve Fan and Tahseen Younus

This case study aims to estimate weak lensing shear and convergence for the DC2 simulated sky survey. See `notebooks/dc2/manuscript` for our most recent results.
In this case study, we use neural posterior estimation to infer tomographic shear and convergence maps from LSST-like images.

Some useful commands:

- Train `lensing_encoder` on DC2 images
To train the encoder on DC2 images, run

```
nohup bliss -cp /home/twhit/bliss/case_studies/weak_lensing/ -cn lensing_config_dc2.yaml mode=train &> train_on_dc2.out &
nohup bliss -cp <path>/bliss/case_studies/weak_lensing/dc2 -cn config_dc2.yaml mode=train &> train_on_dc2.out &
```

- Generate synthetic images with shear and convergence, as specified in `lensing_prior`
To train the encoder on descwl-shear-sims images, run

```
nohup bliss -cp /home/twhit/bliss/case_studies/weak_lensing/ -cn lensing_config_simulator.yaml mode=generate &> generate_synthetic.out &
nohup bliss -cp <path>/bliss/case_studies/weak_lensing/descwl -cn config_descwl.yaml mode=train &> train_on_descwl.out &
```

- Train `lensing_encoder` on synthetic images:

```
nohup bliss -cp /home/twhit/bliss/case_studies/weak_lensing/ -cn lensing_config_simulator.yaml mode=train &> train_on_synthetic.out &
```
See `dc2/notebooks` and `descwl/notebooks` for some exploratory plots and our most recent results.
11 changes: 5 additions & 6 deletions case_studies/weak_lensing/convnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@ def __init__(
self,
n_bands,
ch_per_band,
n_pixels_per_side,
n_tiles_per_side,
res_init,
res_midpoint,
res_final,
ch_init,
ch_max,
ch_final,
Expand All @@ -26,8 +27,6 @@ def __init__(
if n_var_params is not None:
ch_final = max(ch_final, 2 ** math.ceil(math.log2(n_var_params)))

res_midpoint = int(math.sqrt(n_pixels_per_side * n_tiles_per_side))

self.preprocess3d = nn.Sequential(
nn.Conv3d(n_bands, ch_init, [ch_per_band, 5, 5], padding=[0, 2, 2]),
nn.GroupNorm(num_groups=32, num_channels=ch_init),
Expand All @@ -38,9 +37,9 @@ def __init__(
ch_init,
ch_max,
ch_final,
n_pixels_per_side,
res_init,
res_midpoint,
n_tiles_per_side,
res_final,
initial_downsample,
more_up_layers,
num_bottleneck_layers,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
---
defaults:
- ../../../bliss/conf@_here_: base_config
- _self_
- override hydra/job_logging: stdout

mode: train
# completely disable hydra logging
# https://github.com/facebookresearch/hydra/issues/910
hydra:
output_subdir: null
run:
dir: .

paths:
dc2: /data/scratch/dc2local
Expand Down Expand Up @@ -46,27 +50,27 @@ lensing_plots:

encoder:
_target_: case_studies.weak_lensing.encoder.WeakLensingEncoder
survey_bands: ["u", "g", "r", "i", "z", "y"]
reference_band: 2 # r-band
tile_slen: 256
n_pixels_per_side: 2048
n_tiles_per_side: 8
n_bands: 6
res_init: 2048 # num pixels per side
res_midpoint: 128
res_final: 8 # num tiles per side
ch_init: 64
ch_max: 1024
ch_final: 128
initial_downsample: true
initial_downsample: false
more_up_layers: true
num_bottleneck_layers: 0
image_normalizers: ${lensing_normalizers_nanpsf}
var_dist:
_target_: bliss.encoder.variational_dist.VariationalDist
tile_slen: ${surveys.dc2.tile_slen}
factors: ${variational_factors}
optimizer_params:
lr: 1e-4
scheduler_params:
milestones: []
gamma: 1.0
image_normalizers: ${lensing_normalizers_nanpsf}
var_dist:
_target_: bliss.encoder.variational_dist.VariationalDist
tile_slen: ${encoder.tile_slen}
factors: ${variational_factors}
loss_plots_location: ${paths.output}/${train.trainer.logger.name}/${train.trainer.logger.version}/loss_plots
mode_metrics:
_target_: torchmetrics.MetricCollection
_convert_: partial
Expand All @@ -79,9 +83,6 @@ encoder:
_target_: torchmetrics.MetricCollection
_convert_: partial
metrics: ${lensing_plots}
use_double_detect: false
use_checkerboard: false
loss_plots_location: ${paths.output}/${train.trainer.logger.name}/${train.trainer.logger.version}/loss_plots

surveys:
dc2:
Expand All @@ -94,27 +95,50 @@ surveys:
splits: 0:80/80:90/90:100
avg_ellip_kernel_size: 15 # must be odd
avg_ellip_kernel_sigma: 15
redshift_quantiles: [0.00]
num_redshift_bins: 1 # length of redshift_quantiles
redshift_quantiles: [0.00, 0.762988, 1.120420, 1.592735]
num_redshift_bins: 4 # length of redshift_quantiles
batch_size: 1
num_workers: 1
cached_data_path: ${paths.dc2}/dc2_lensing_splits
train_transforms:
- _target_: case_studies.weak_lensing.data_augmentation.LensingRotateFlipTransform
shuffle_file_order: false # partition train/val/test by ra/dec so that we can compute 2PCFs on spatially contiguous test set

mode: train

train:
trainer:
_target_: pytorch_lightning.Trainer
logger:
name: WeakLensingResultsDC2
version: ${now:%Y-%m-%d_%H-%M}
_target_: pytorch_lightning.loggers.TensorBoardLogger
save_dir: ${paths.output}
name: ResultsDC2
version: nanpsf_${encoder.ch_max}_${encoder.initial_downsample}_${encoder.more_up_layers}
default_hp_metric: false
reload_dataloaders_every_n_epochs: 0
check_val_every_n_epoch: 1
log_every_n_steps: 10
max_epochs: 300
accelerator: gpu
devices: 1
use_distributed_sampler: false
precision: 32-true
callbacks:
checkpointing:
_target_: pytorch_lightning.callbacks.ModelCheckpoint
filename: "encoder_{epoch}"
save_top_k: 1
verbose: true
monitor: val/_loss
mode: min
save_on_train_epoch_end: false
early_stopping:
_target_: pytorch_lightning.callbacks.early_stopping.EarlyStopping
monitor: val/_loss
mode: min
patience: 100
data_source: ${surveys.dc2}
pretrained_weights: null
encoder: ${encoder}
seed: 123123
pretrained_weights: null
ckpt_path: null
matmul_precision: high
Loading