Skip to content

Commit 3363e0a

Browse files
authored
Cleaned up input to TextImage (#2387)
1 parent d6df48d commit 3363e0a

File tree

4 files changed

+46
-41
lines changed

4 files changed

+46
-41
lines changed

albumentations/augmentations/text/functional.py

+27-22
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
import random
4-
from collections.abc import Sequence
54
from typing import TYPE_CHECKING, Any
65

76
import cv2
@@ -75,7 +74,7 @@ def convert_image_to_pil(image: np.ndarray) -> Image:
7574

7675

7776
def draw_text_on_pil_image(pil_image: Image, metadata_list: list[dict[str, Any]]) -> Image:
78-
"""Draw text on a PIL image using metadata information."""
77+
"""Draw text on a PIL image."""
7978
try:
8079
from PIL import ImageDraw
8180
except ImportError:
@@ -87,10 +86,20 @@ def draw_text_on_pil_image(pil_image: Image, metadata_list: list[dict[str, Any]]
8786
text = metadata["text"]
8887
font = metadata["font"]
8988
font_color = metadata["font_color"]
90-
if isinstance(font_color, (list, tuple)):
89+
90+
# Adapt font_color based on image mode
91+
if pil_image.mode == "L": # Grayscale
92+
# For grayscale images, use only the first value or average the RGB values
93+
if isinstance(font_color, tuple):
94+
if len(font_color) >= 3:
95+
# Average RGB values for grayscale
96+
font_color = int(sum(font_color[:3]) / 3)
97+
elif len(font_color) == 1:
98+
font_color = int(font_color[0])
99+
# For RGB and other modes, ensure font_color is a tuple of integers
100+
elif isinstance(font_color, tuple):
91101
font_color = tuple(int(c) for c in font_color)
92-
elif isinstance(font_color, float):
93-
font_color = int(font_color)
102+
94103
position = bbox_coords[:2]
95104
draw.text(position, text, font=font, fill=font_color)
96105
return pil_image
@@ -112,27 +121,23 @@ def draw_text_on_multi_channel_image(image: np.ndarray, metadata_list: list[dict
112121
font = metadata["font"]
113122
font_color = metadata["font_color"]
114123

115-
# Handle different font_color types
116-
if isinstance(font_color, str):
117-
# If it's a string, use it as is for all channels
118-
font_color = [font_color] * image.shape[2]
119-
elif isinstance(font_color, (int, float)):
120-
# If it's a single number, convert to int and use for all channels
121-
font_color = [int(font_color)] * image.shape[2]
122-
elif isinstance(font_color, Sequence):
123-
# If it's a sequence, ensure it has the right length and convert to int
124-
if len(font_color) != image.shape[2]:
125-
raise ValueError(
126-
f"font_color sequence length ({len(font_color)}) "
127-
f"must match the number of image channels ({image.shape[2]})",
128-
)
129-
font_color = [int(c) for c in font_color]
130-
else:
131-
raise TypeError(f"Unsupported font_color type: {type(font_color)}")
124+
# Handle font_color as tuple[float, ...]
125+
# Ensure we have enough color values for all channels
126+
if len(font_color) < image.shape[2]:
127+
# If fewer values than channels, pad with zeros
128+
font_color = tuple(list(font_color) + [0] * (image.shape[2] - len(font_color)))
129+
elif len(font_color) > image.shape[2]:
130+
# If more values than channels, truncate
131+
font_color = font_color[: image.shape[2]]
132+
133+
# Convert to integers for PIL
134+
font_color = [int(c) for c in font_color]
132135

133136
position = bbox_coords[:2]
134137

138+
# For each channel, use the corresponding color value
135139
for channel_id, pil_image in enumerate(pil_images):
140+
# For single-channel PIL images, color must be an integer
136141
pil_image.text(position, text, font=font, fill=font_color[channel_id])
137142

138143
return np.stack([np.array(channel) for channel in channels], axis=2)

albumentations/augmentations/text/transforms.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,15 @@ class TextImage(ImageOnlyTransform):
2525
Args:
2626
font_path (str | Path): Path to the font file to use for rendering text.
2727
stopwords (tuple[str, ...]): Stopwords used for text augmentation.
28-
augmentations (tuple[str | None, ...] | list[str | None]): List of text augmentations to apply.
28+
augmentations (tuple[str | None, ...]): Tuple of text augmentations to apply.
2929
None: text is printed as is
3030
"insertion": insert random stop words into the text.
3131
"swap": swap random words in the text.
3232
"deletion": delete random words from the text.
3333
fraction_range (tuple[float, float]): Range for selecting a fraction of bounding boxes to modify.
3434
font_size_fraction_range (tuple[float, float]): Range for selecting the font size as a fraction of
3535
bounding box height.
36-
font_color (list[str] | str): List of possible font colors or a single font color.
36+
font_color (tuple[float, ...]): Font color as RGB values (e.g., (0, 0, 0) for black).
3737
clear_bg (bool): Whether to clear the background before rendering text.
3838
metadata_key (str): Key to access metadata in the parameters.
3939
p (float): Probability of applying the transform.
@@ -52,11 +52,11 @@ class TextImage(ImageOnlyTransform):
5252
>>> transform = A.Compose([
5353
A.TextImage(
5454
font_path=Path("/path/to/font.ttf"),
55-
stopwords=["the", "is", "in"],
55+
stopwords=("the", "is", "in"),
5656
augmentations=("insertion", "deletion"),
5757
fraction_range=(0.5, 1.0),
5858
font_size_fraction_range=(0.5, 0.9),
59-
font_color=["red", "green", "blue"],
59+
font_color=(255, 0, 0), # red in RGB
6060
metadata_key="text_metadata",
6161
p=0.5
6262
)
@@ -69,7 +69,7 @@ class TextImage(ImageOnlyTransform):
6969
class InitSchema(BaseTransformInitSchema):
7070
font_path: str | Path
7171
stopwords: tuple[str, ...]
72-
augmentations: tuple[str | None, ...] | list[str | None]
72+
augmentations: tuple[str | None, ...]
7373
fraction_range: Annotated[
7474
tuple[float, float],
7575
AfterValidator(nondecreasing),
@@ -80,18 +80,18 @@ class InitSchema(BaseTransformInitSchema):
8080
AfterValidator(nondecreasing),
8181
AfterValidator(check_range_bounds(0, 1)),
8282
]
83-
font_color: list[tuple[float, ...] | float | str] | tuple[float, ...] | float | str
83+
font_color: tuple[float, ...]
8484
clear_bg: bool
8585
metadata_key: str
8686

8787
def __init__(
8888
self,
8989
font_path: str | Path,
9090
stopwords: tuple[str, ...] = ("the", "is", "in", "at", "of"),
91-
augmentations: tuple[Literal["insertion", "swap", "deletion"] | None] = (None,),
91+
augmentations: tuple[Literal["insertion", "swap", "deletion"] | None, ...] = (None,),
9292
fraction_range: tuple[float, float] = (1.0, 1.0),
9393
font_size_fraction_range: tuple[float, float] = (0.8, 0.9),
94-
font_color: list[tuple[float, ...] | float | str] | tuple[float, ...] | float | str = "black",
94+
font_color: tuple[float, ...] = (0, 0, 0), # black in RGB
9595
clear_bg: bool = False,
9696
metadata_key: str = "textimage_metadata",
9797
p: float = 0.5,
@@ -174,7 +174,7 @@ def preprocess_metadata(
174174

175175
augmented_text = text if augmentation is None else self.random_aug(text, 0.5, choice=augmentation)
176176

177-
font_color = self.py_random.choice(self.font_color) if isinstance(self.font_color, list) else self.font_color
177+
font_color = self.font_color
178178

179179
return {
180180
"bbox_coords": (x_min, y_min, x_max, y_max),

tests/aug_definitions.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -389,8 +389,8 @@
389389
dict(
390390
font_path="./tests/files/LiberationSerif-Bold.ttf",
391391
font_size_fraction_range=(0.8, 0.9),
392-
font_color="red",
393-
stopwords=[
392+
font_color=(255, 0, 0), # red in RGB
393+
stopwords=(
394394
"a",
395395
"the",
396396
"is",
@@ -404,7 +404,7 @@
404404
"for",
405405
"at",
406406
"by",
407-
],
407+
),
408408
),
409409
],
410410
[A.GridElasticDeform, {"num_grid_xy": (10, 10), "magnitude": 10}],

tests/test_text.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def test_convert_image_to_pil(image_shape):
128128
"bbox_coords": (10, 10, 100, 50),
129129
"text": "Test",
130130
"font": font,
131-
"font_color": 127, # Red color
131+
"font_color": (127,), # Grayscale color
132132
},
133133
],
134134
), # Grayscale image
@@ -139,7 +139,7 @@ def test_convert_image_to_pil(image_shape):
139139
"bbox_coords": (10, 10, 100, 50),
140140
"text": "Test",
141141
"font": font,
142-
"font_color": 127, # Red color
142+
"font_color": (127,), # Single channel color
143143
},
144144
],
145145
), # Single channel image
@@ -150,24 +150,24 @@ def test_convert_image_to_pil(image_shape):
150150
"bbox_coords": (10, 10, 100, 50),
151151
"text": "Test",
152152
"font": font,
153-
"font_color": (127, 127, 127), # Red color
153+
"font_color": (127, 127, 127), # RGB color
154154
},
155155
{
156156
"bbox_coords": (20, 20, 110, 60),
157157
"text": "Test",
158158
"font": font,
159-
"font_color": "red", # Red color
159+
"font_color": (255, 0, 0), # Red color
160160
},
161161
],
162-
), # RGB image with tuple and string font color
162+
), # RGB image with tuple colors
163163
(
164164
(100, 100, 5),
165165
[
166166
{
167167
"bbox_coords": (20, 20, 110, 60),
168168
"text": "Test",
169169
"font": font,
170-
"font_color": (127, 127, 127, 127, 127),
170+
"font_color": (127, 127, 127, 127, 127), # 5-channel color
171171
},
172172
],
173173
),
@@ -178,7 +178,7 @@ def test_convert_image_to_pil(image_shape):
178178
"bbox_coords": (20, 20, 110, 60),
179179
"text": "Test",
180180
"font": font,
181-
"font_color": "red",
181+
"font_color": (255, 0, 0), # RGB color (will be padded for 5 channels)
182182
},
183183
],
184184
),

0 commit comments

Comments
 (0)