Skip to content

Commit 6c5340d

Browse files
Two speedups (opengeos#262)
* Two speedups
  - Allow the sample window size to be set from samgeo.generate. Fewer, larger windows take less time as long as you have the memory.
  - Skip any sample windows that are 100% nodata, with the option to set lower thresholds.
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 5664743 commit 6c5340d

File tree

2 files changed

+22
-0
lines changed

2 files changed

+22
-0
lines changed

samgeo/common.py

+10
Original file line numberDiff line numberDiff line change
@@ -1122,6 +1122,8 @@ def tiff_to_tiff(
11221122
func,
11231123
data_to_rgb=chw_to_hwc,
11241124
sample_size=(512, 512),
1125+
sample_nodata_threshold=1.0,
1126+
nodata_value=None,
11251127
sample_resize=None,
11261128
bound=128,
11271129
foreground=True,
@@ -1132,6 +1134,9 @@ def tiff_to_tiff(
11321134
with rasterio.open(src_fp) as src:
11331135
profile = src.profile
11341136

1137+
if nodata_value is None:
1138+
nodata_values = profile.get("nodata", None)
1139+
11351140
# Compute blocks
11361141
rh, rw = profile["height"], profile["width"]
11371142
sh, sw = sample_size
@@ -1154,6 +1159,11 @@ def tiff_to_tiff(
11541159
for b in tqdm(sample_grid):
11551160
# Read each tile from the source
11561161
r = read_block(src, **b)
1162+
1163+
if nodata_value is not None:
1164+
if (r == nodata_value).mean() >= sample_nodata_threshold:
1165+
continue
1166+
11571167
# Extract the first 3 channels as RGB
11581168
uint8_rgb_in = data_to_rgb(r)
11591169
orig_size = uint8_rgb_in.shape[:2]

samgeo/samgeo.py

+12
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,9 @@ def generate(
152152
output=None,
153153
foreground=True,
154154
batch=False,
155+
batch_sample_size=(512, 512),
156+
batch_nodata_threshold=1.0,
157+
nodata_value=None,
155158
erosion_kernel=None,
156159
mask_multiplier=255,
157160
unique=True,
@@ -164,6 +167,12 @@ def generate(
164167
output (str, optional): The path to the output image. Defaults to None.
165168
foreground (bool, optional): Whether to generate the foreground mask. Defaults to True.
166169
batch (bool, optional): Whether to generate masks for a batch of image tiles. Defaults to False.
170+
batch_sample_size (tuple, optional): When batch=True, the size of the sample window used when iterating over rasters. Defaults to (512, 512).
171+
batch_nodata_threshold (float, optional): Batch samples with a fraction of nodata pixels at or above this threshold will
172+
not be used to generate a mask. The default, 1.0, will skip samples with 100% nodata values. This is useful
173+
when rasters have large areas of nodata values which can be skipped.
174+
nodata_value (int, optional): Nodata value to use in checking batch_nodata_threshold. The default, None,
175+
will use the nodata value in the raster metadata if present.
167176
erosion_kernel (tuple, optional): The erosion kernel for filtering object masks and extract borders.
168177
Such as (3, 3) or (5, 5). Set to None to disable it. Defaults to None.
169178
mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].
@@ -190,6 +199,9 @@ def generate(
190199
output,
191200
self,
192201
foreground=foreground,
202+
sample_size=batch_sample_size,
203+
sample_nodata_threshold=batch_nodata_threshold,
204+
nodata_value=nodata_value,
193205
erosion_kernel=erosion_kernel,
194206
mask_multiplier=mask_multiplier,
195207
**kwargs,

0 commit comments

Comments
 (0)