Skip to content

Commit cfbc6d0

Browse files
authored
Add support for more than 512 channels in horizontal flip (#49)
* Cleanup * Support for more than 512 channels in flip * Support for more than 512 channels in flip
1 parent f6e5bfc commit cfbc6d0

File tree

5 files changed

+187
-63
lines changed

5 files changed

+187
-63
lines changed

.pre-commit-config.yaml

+4-4
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ repos:
4444
- id: python-use-type-annotations
4545
- id: text-unicode-replacement-char
4646
- repo: https://github.com/codespell-project/codespell
47-
rev: v2.3.0
47+
rev: v2.4.1
4848
hooks:
4949
- id: codespell
5050
additional_dependencies: ["tomli"]
@@ -53,13 +53,13 @@ repos:
5353
# hooks:
5454
# - id: markdownlint
5555
- repo: https://github.com/tox-dev/pyproject-fmt
56-
rev: "v2.5.0"
56+
rev: "v2.5.1"
5757
hooks:
5858
- id: pyproject-fmt
5959
additional_dependencies: ["tomli"]
6060
- repo: https://github.com/astral-sh/ruff-pre-commit
6161
# Ruff version.
62-
rev: v0.8.4
62+
rev: v0.9.10
6363
hooks:
6464
# Run the linter.
6565
- id: ruff
@@ -68,7 +68,7 @@ repos:
6868
# Run the formatter.
6969
- id: ruff-format
7070
- repo: https://github.com/pre-commit/mirrors-mypy
71-
rev: v1.14.0
71+
rev: v1.15.0
7272
hooks:
7373
- id: mypy
7474
files: ^(albucore|benchmark)/

albucore/functions.py

+92-35
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,10 @@ def add_weighted_simsimd(img1: np.ndarray, weight1: float, img2: np.ndarray, wei
3232
original_dtype = img1.dtype
3333

3434
if img2.dtype != original_dtype:
35-
img2 = clip(img2.astype(original_dtype), original_dtype, inplace=True)
35+
img2 = clip(img2.astype(original_dtype, copy=False), original_dtype, inplace=True)
3636

3737
return np.frombuffer(
38-
ss.wsum(img1.reshape(-1), img2.astype(original_dtype).reshape(-1), alpha=weight1, beta=weight2),
38+
ss.wsum(img1.reshape(-1), img2.astype(original_dtype, copy=False).reshape(-1), alpha=weight1, beta=weight2),
3939
dtype=original_dtype,
4040
).reshape(
4141
original_shape,
@@ -51,7 +51,7 @@ def multiply_by_constant_simsimd(img: np.ndarray, value: float) -> np.ndarray:
5151

5252

5353
def add_constant_simsimd(img: np.ndarray, value: float) -> np.ndarray:
54-
return add_weighted_simsimd(img, 1, (np.ones_like(img) * value).astype(img.dtype), 1)
54+
return add_weighted_simsimd(img, 1, (np.ones_like(img) * value).astype(img.dtype, copy=False), 1)
5555

5656

5757
def create_lut_array(
@@ -92,11 +92,12 @@ def apply_lut(
9292

9393
if isinstance(value, (int, float)):
9494
lut = create_lut_array(dtype, value, operation)
95-
return sz_lut(img, clip(lut, dtype), inplace)
95+
return sz_lut(img, clip(lut, dtype, inplace=False), False)
9696

9797
num_channels = img.shape[-1]
98-
luts = create_lut_array(dtype, value, operation)
99-
return cv2.merge([sz_lut(img[:, :, i], clip(luts[i], dtype, inplace=False), inplace) for i in range(num_channels)])
98+
99+
luts = clip(create_lut_array(dtype, value, operation), dtype, inplace=False)
100+
return cv2.merge([sz_lut(img[:, :, i], luts[i], inplace) for i in range(num_channels)])
100101

101102

102103
def prepare_value_opencv(
@@ -135,14 +136,14 @@ def _prepare_array_value(
135136
operation: Literal["add", "multiply"],
136137
) -> np.ndarray:
137138
if value.dtype == np.float64:
138-
value = value.astype(np.float32)
139+
value = value.astype(np.float32, copy=False)
139140
if value.ndim == 1:
140141
value = value.reshape(1, 1, -1)
141142
value = np.broadcast_to(value, img.shape)
142143
if operation == "add" and img.dtype == np.uint8:
143144
if np.all(value >= 0):
144-
return clip(value, np.uint8)
145-
return np.trunc(value).astype(np.float32)
145+
return clip(value, np.uint8, inplace=False)
146+
return np.trunc(value).astype(np.float32, copy=False)
146147
return value
147148

148149

@@ -154,7 +155,7 @@ def apply_numpy(
154155
if operation == "add" and img.dtype == np.uint8:
155156
value = np.int16(value)
156157

157-
return np_operations[operation](img.astype(np.float32), value)
158+
return np_operations[operation](img.astype(np.float32, copy=False), value)
158159

159160

160161
def multiply_lut(img: np.ndarray, value: np.ndarray | float, inplace: bool) -> np.ndarray:
@@ -165,7 +166,7 @@ def multiply_lut(img: np.ndarray, value: np.ndarray | float, inplace: bool) -> n
165166
def multiply_opencv(img: np.ndarray, value: np.ndarray | float) -> np.ndarray:
166167
value = prepare_value_opencv(img, value, "multiply")
167168
if img.dtype == np.uint8:
168-
return cv2.multiply(img.astype(np.float32), value)
169+
return cv2.multiply(img.astype(np.float32, copy=False), value)
169170
return cv2.multiply(img, value)
170171

171172

@@ -223,7 +224,10 @@ def add_opencv(img: np.ndarray, value: np.ndarray | float, inplace: bool = False
223224
)
224225

225226
if needs_float:
226-
return cv2.add(img.astype(np.float32), value if isinstance(value, (int, float)) else value.astype(np.float32))
227+
return cv2.add(
228+
img.astype(np.float32, copy=False),
229+
value if isinstance(value, (int, float)) else value.astype(np.float32, copy=False),
230+
)
227231

228232
# Use img as the destination array if inplace=True
229233
dst = img if inplace else None
@@ -272,14 +276,14 @@ def add(img: np.ndarray, value: ValueType, inplace: bool = False) -> np.ndarray:
272276

273277

274278
def normalize_numpy(img: np.ndarray, mean: float | np.ndarray, denominator: float | np.ndarray) -> np.ndarray:
275-
img = img.astype(np.float32)
279+
img = img.astype(np.float32, copy=False)
276280
img -= mean
277281
return img * denominator
278282

279283

280284
@preserve_channel_dim
281285
def normalize_opencv(img: np.ndarray, mean: float | np.ndarray, denominator: float | np.ndarray) -> np.ndarray:
282-
img = img.astype(np.float32)
286+
img = img.astype(np.float32, copy=False)
283287
mean_img = np.zeros_like(img, dtype=np.float32)
284288
denominator_img = np.zeros_like(img, dtype=np.float32)
285289

@@ -290,7 +294,7 @@ def normalize_opencv(img: np.ndarray, mean: float | np.ndarray, denominator: flo
290294
denominator = np.full(img.shape, denominator, dtype=np.float32)
291295

292296
# Ensure the shapes match for broadcasting
293-
mean_img = (mean_img + mean).astype(np.float32)
297+
mean_img = (mean_img + mean).astype(np.float32, copy=False)
294298
denominator_img = denominator_img + denominator
295299

296300
result = cv2.subtract(img, mean_img)
@@ -343,7 +347,7 @@ def power_opencv(img: np.ndarray, value: float) -> np.ndarray:
343347
return cv2.pow(img, value)
344348
if img.dtype == np.uint8 and isinstance(value, float):
345349
# For uint8 images, convert to float32, apply power, then convert back to uint8
346-
img_float = img.astype(np.float32)
350+
img_float = img.astype(np.float32, copy=False)
347351
return cv2.pow(img_float, value)
348352

349353
raise ValueError(f"Unsupported image type {img.dtype} for power operation with value {value}")
@@ -368,7 +372,7 @@ def power(img: np.ndarray, exponent: ValueType, inplace: bool = False) -> np.nda
368372

369373

370374
def add_weighted_numpy(img1: np.ndarray, weight1: float, img2: np.ndarray, weight2: float) -> np.ndarray:
371-
return img1.astype(np.float32) * weight1 + img2.astype(np.float32) * weight2
375+
return img1.astype(np.float32, copy=False) * weight1 + img2.astype(np.float32, copy=False) * weight2
372376

373377

374378
@preserve_channel_dim
@@ -430,7 +434,7 @@ def multiply_add_opencv(img: np.ndarray, factor: ValueType, value: ValueType) ->
430434
if isinstance(value, (int, float)) and value == 0 and isinstance(factor, (int, float)) and factor == 0:
431435
return np.zeros_like(img)
432436

433-
result = img.astype(np.float32)
437+
result = img.astype(np.float32, copy=False)
434438
result = (
435439
cv2.multiply(result, np.ones_like(result) * factor, dtype=cv2.CV_64F)
436440
if factor != 0
@@ -473,7 +477,7 @@ def multiply_add(img: np.ndarray, factor: ValueType, value: ValueType, inplace:
473477

474478
@preserve_channel_dim
475479
def normalize_per_image_opencv(img: np.ndarray, normalization: NormalizationType) -> np.ndarray:
476-
img = img.astype(np.float32)
480+
img = img.astype(np.float32, copy=False)
477481
eps = 1e-4
478482

479483
if img.ndim == MONO_CHANNEL_DIMENSIONS:
@@ -486,7 +490,7 @@ def normalize_per_image_opencv(img: np.ndarray, normalization: NormalizationType
486490
mean = np.full_like(img, mean)
487491
std = np.full_like(img, std)
488492
normalized_img = cv2.divide(cv2.subtract(img, mean), std)
489-
return normalized_img.clip(-20, 20)
493+
return np.clip(normalized_img, -20, 20, out=normalized_img)
490494

491495
if normalization == "image_per_channel":
492496
mean, std = cv2.meanStdDev(img)
@@ -498,7 +502,7 @@ def normalize_per_image_opencv(img: np.ndarray, normalization: NormalizationType
498502
std = np.full_like(img, std)
499503

500504
normalized_img = cv2.divide(cv2.subtract(img, mean), std, dtype=cv2.CV_32F)
501-
return normalized_img.clip(-20, 20)
505+
return np.clip(normalized_img, -20, 20, out=normalized_img)
502506

503507
if normalization == "min_max" or (img.shape[-1] == 1 and normalization == "min_max_per_channel"):
504508
img_min = img.min()
@@ -513,14 +517,19 @@ def normalize_per_image_opencv(img: np.ndarray, normalization: NormalizationType
513517
img_min = np.full_like(img, img_min)
514518
img_max = np.full_like(img, img_max)
515519

516-
return cv2.divide(cv2.subtract(img, img_min), (img_max - img_min + eps), dtype=cv2.CV_32F).clip(-20, 20)
520+
return np.clip(
521+
cv2.divide(cv2.subtract(img, img_min), (img_max - img_min + eps), dtype=cv2.CV_32F),
522+
-20,
523+
20,
524+
out=img,
525+
)
517526

518527
raise ValueError(f"Unknown normalization method: {normalization}")
519528

520529

521530
@preserve_channel_dim
522531
def normalize_per_image_numpy(img: np.ndarray, normalization: NormalizationType) -> np.ndarray:
523-
img = img.astype(np.float32)
532+
img = img.astype(np.float32, copy=False)
524533
eps = 1e-4
525534

526535
if img.ndim == MONO_CHANNEL_DIMENSIONS:
@@ -530,23 +539,23 @@ def normalize_per_image_numpy(img: np.ndarray, normalization: NormalizationType)
530539
mean = img.mean()
531540
std = img.std() + eps
532541
normalized_img = (img - mean) / std
533-
return normalized_img.clip(-20, 20)
542+
return np.clip(normalized_img, -20, 20, out=normalized_img)
534543

535544
if normalization == "image_per_channel":
536545
pixel_mean = img.mean(axis=(0, 1))
537546
pixel_std = img.std(axis=(0, 1)) + eps
538547
normalized_img = (img - pixel_mean) / pixel_std
539-
return normalized_img.clip(-20, 20)
548+
return np.clip(normalized_img, -20, 20, out=normalized_img)
540549

541550
if normalization == "min_max":
542551
img_min = img.min()
543552
img_max = img.max()
544-
return (img - img_min) / (img_max - img_min + eps)
553+
return np.clip((img - img_min) / (img_max - img_min + eps), -20, 20, out=img)
545554

546555
if normalization == "min_max_per_channel":
547556
img_min = img.min(axis=(0, 1))
548557
img_max = img.max(axis=(0, 1))
549-
return (img - img_min) / (img_max - img_min + eps)
558+
return np.clip((img - img_min) / (img_max - img_min + eps), -20, 20, out=img)
550559

551560
raise ValueError(f"Unknown normalization method: {normalization}")
552561

@@ -579,7 +588,7 @@ def normalize_per_image_lut(img: np.ndarray, normalization: NormalizationType) -
579588
img_min = img.min()
580589
img_max = img.max()
581590
lut = (np.arange(0, max_value + 1, dtype=np.float32) - img_min) / (img_max - img_min + eps)
582-
return cv2.LUT(img, lut)
591+
return cv2.LUT(img, lut).clip(-20, 20)
583592

584593
if normalization == "min_max_per_channel":
585594
img_min = img.min(axis=(0, 1))
@@ -604,15 +613,15 @@ def normalize_per_image(img: np.ndarray, normalization: NormalizationType) -> np
604613
def to_float_numpy(img: np.ndarray, max_value: float | None = None) -> np.ndarray:
605614
if max_value is None:
606615
max_value = get_max_value(img.dtype)
607-
return (img / max_value).astype(np.float32)
616+
return (img / max_value).astype(np.float32, copy=False)
608617

609618

610619
@preserve_channel_dim
611620
def to_float_opencv(img: np.ndarray, max_value: float | None = None) -> np.ndarray:
612621
if max_value is None:
613622
max_value = get_max_value(img.dtype)
614623

615-
img_float = img.astype(np.float32)
624+
img_float = img.astype(np.float32, copy=False)
616625

617626
num_channels = get_num_channels(img)
618627

@@ -638,7 +647,7 @@ def to_float_lut(img: np.ndarray, max_value: float | None = None) -> np.ndarray:
638647

639648
def to_float(img: np.ndarray, max_value: float | None = None) -> np.ndarray:
640649
if img.dtype == np.float64:
641-
return img.astype(np.float32)
650+
return img.astype(np.float32, copy=False)
642651
if img.dtype == np.float32:
643652
return img
644653
if img.dtype == np.uint8:
@@ -657,17 +666,17 @@ def from_float_opencv(img: np.ndarray, target_dtype: np.dtype, max_value: float
657666
if max_value is None:
658667
max_value = get_max_value(target_dtype)
659668

660-
img_float = img.astype(np.float32)
669+
img_float = img.astype(np.float32, copy=False)
661670

662671
num_channels = get_num_channels(img)
663672

664673
if num_channels > MAX_OPENCV_WORKING_CHANNELS:
665674
# For images with more than 4 channels, create a full-sized multiplier
666675
max_value_array = np.full_like(img_float, max_value)
667-
return clip(np.rint(cv2.multiply(img_float, max_value_array)), target_dtype)
676+
return clip(np.rint(cv2.multiply(img_float, max_value_array)), target_dtype, inplace=False)
668677

669678
# For images with 4 or fewer channels, use scalar multiplication
670-
return clip(np.rint(img * max_value), target_dtype)
679+
return clip(np.rint(img * max_value), target_dtype, inplace=False)
671680

672681

673682
def from_float(img: np.ndarray, target_dtype: np.dtype, max_value: float | None = None) -> np.ndarray:
@@ -695,7 +704,7 @@ def from_float(img: np.ndarray, target_dtype: np.dtype, max_value: float | None
695704
return img
696705

697706
if target_dtype == np.float64:
698-
return img.astype(np.float32)
707+
return img.astype(np.float32, copy=False)
699708

700709
if img.dtype == np.float32:
701710
return from_float_opencv(img, target_dtype, max_value)
@@ -710,6 +719,9 @@ def hflip_numpy(img: np.ndarray) -> np.ndarray:
710719

711720
@preserve_channel_dim
712721
def hflip_cv2(img: np.ndarray) -> np.ndarray:
722+
# OpenCV's flip function has a limitation of 512 channels
723+
if img.ndim > 2 and img.shape[2] > 512:
724+
return _flip_multichannel(img, flip_code=1)
713725
return cv2.flip(img, 1)
714726

715727

@@ -719,6 +731,9 @@ def hflip(img: np.ndarray) -> np.ndarray:
719731

720732
@preserve_channel_dim
721733
def vflip_cv2(img: np.ndarray) -> np.ndarray:
734+
# OpenCV's flip function has a limitation of 512 channels
735+
if img.ndim > 2 and img.shape[2] > 512:
736+
return _flip_multichannel(img, flip_code=0)
722737
return cv2.flip(img, 0)
723738

724739

@@ -731,6 +746,48 @@ def vflip(img: np.ndarray) -> np.ndarray:
731746
return vflip_cv2(img)
732747

733748

749+
def _flip_multichannel(img: np.ndarray, flip_code: int) -> np.ndarray:
750+
"""Process images with more than 512 channels by splitting into chunks.
751+
752+
OpenCV's flip function has a limitation where it can only handle images with up to 512 channels.
753+
This function works around that limitation by splitting the image into chunks of 512 channels,
754+
flipping each chunk separately, and then concatenating the results.
755+
756+
Args:
757+
img: Input image with many channels
758+
flip_code: OpenCV flip code (0 for vertical, 1 for horizontal, -1 for both)
759+
760+
Returns:
761+
Flipped image with all channels preserved
762+
"""
763+
# Get image dimensions
764+
height, width = img.shape[:2]
765+
num_channels = 1 if img.ndim == 2 else img.shape[2]
766+
767+
# If the image has 2 dimensions or fewer than 512 channels, use cv2.flip directly
768+
if img.ndim == 2 or num_channels <= 512:
769+
return cv2.flip(img, flip_code)
770+
771+
# Process in chunks of 512 channels
772+
chunk_size = 512
773+
result_chunks = []
774+
775+
for i in range(0, num_channels, chunk_size):
776+
end_idx = min(i + chunk_size, num_channels)
777+
chunk = img[:, :, i:end_idx]
778+
flipped_chunk = cv2.flip(chunk, flip_code)
779+
780+
# Ensure the chunk maintains its dimensionality
781+
# This is needed when the last chunk has only one channel and cv2.flip reduces the dimensions
782+
if flipped_chunk.ndim == 2 and img.ndim == 3:
783+
flipped_chunk = np.expand_dims(flipped_chunk, axis=2)
784+
785+
result_chunks.append(flipped_chunk)
786+
787+
# Concatenate the chunks along the channel dimension
788+
return np.concatenate(result_chunks, axis=2)
789+
790+
734791
def float32_io(func: Callable[..., np.ndarray]) -> Callable[..., np.ndarray]:
735792
"""Decorator to ensure float32 input/output for image processing functions.
736793

0 commit comments

Comments
 (0)