Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleaned up input to TextImage #2387

Merged
merged 1 commit into from
Mar 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 27 additions & 22 deletions albumentations/augmentations/text/functional.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import random
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any

import cv2
Expand Down Expand Up @@ -75,7 +74,7 @@ def convert_image_to_pil(image: np.ndarray) -> Image:


def draw_text_on_pil_image(pil_image: Image, metadata_list: list[dict[str, Any]]) -> Image:
"""Draw text on a PIL image using metadata information."""
"""Draw text on a PIL image."""
try:
from PIL import ImageDraw
except ImportError:
Expand All @@ -87,10 +86,20 @@ def draw_text_on_pil_image(pil_image: Image, metadata_list: list[dict[str, Any]]
text = metadata["text"]
font = metadata["font"]
font_color = metadata["font_color"]
if isinstance(font_color, (list, tuple)):

# Adapt font_color based on image mode
if pil_image.mode == "L": # Grayscale
# For grayscale images, use only the first value or average the RGB values
if isinstance(font_color, tuple):
if len(font_color) >= 3:
# Average RGB values for grayscale
font_color = int(sum(font_color[:3]) / 3)
elif len(font_color) == 1:
font_color = int(font_color[0])
# For RGB and other modes, ensure font_color is a tuple of integers
elif isinstance(font_color, tuple):
font_color = tuple(int(c) for c in font_color)
elif isinstance(font_color, float):
font_color = int(font_color)

position = bbox_coords[:2]
draw.text(position, text, font=font, fill=font_color)
return pil_image
Expand All @@ -112,27 +121,23 @@ def draw_text_on_multi_channel_image(image: np.ndarray, metadata_list: list[dict
font = metadata["font"]
font_color = metadata["font_color"]

# Handle different font_color types
if isinstance(font_color, str):
# If it's a string, use it as is for all channels
font_color = [font_color] * image.shape[2]
elif isinstance(font_color, (int, float)):
# If it's a single number, convert to int and use for all channels
font_color = [int(font_color)] * image.shape[2]
elif isinstance(font_color, Sequence):
# If it's a sequence, ensure it has the right length and convert to int
if len(font_color) != image.shape[2]:
raise ValueError(
f"font_color sequence length ({len(font_color)}) "
f"must match the number of image channels ({image.shape[2]})",
)
font_color = [int(c) for c in font_color]
else:
raise TypeError(f"Unsupported font_color type: {type(font_color)}")
# Handle font_color as tuple[float, ...]
# Ensure we have enough color values for all channels
if len(font_color) < image.shape[2]:
# If fewer values than channels, pad with zeros
font_color = tuple(list(font_color) + [0] * (image.shape[2] - len(font_color)))
elif len(font_color) > image.shape[2]:
# If more values than channels, truncate
font_color = font_color[: image.shape[2]]

# Convert to integers for PIL
font_color = [int(c) for c in font_color]

position = bbox_coords[:2]

# For each channel, use the corresponding color value
for channel_id, pil_image in enumerate(pil_images):
# For single-channel PIL images, color must be an integer
pil_image.text(position, text, font=font, fill=font_color[channel_id])

return np.stack([np.array(channel) for channel in channels], axis=2)
Expand Down
18 changes: 9 additions & 9 deletions albumentations/augmentations/text/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@ class TextImage(ImageOnlyTransform):
Args:
font_path (str | Path): Path to the font file to use for rendering text.
stopwords (tuple[str, ...]): Tuple of stopwords for text augmentation.
augmentations (tuple[str | None, ...] | list[str | None]): List of text augmentations to apply.
augmentations (tuple[str | None, ...]): Tuple of text augmentations to apply.
None: text is printed as is
"insertion": insert random stop words into the text.
"swap": swap random words in the text.
"deletion": delete random words from the text.
fraction_range (tuple[float, float]): Range for selecting a fraction of bounding boxes to modify.
font_size_fraction_range (tuple[float, float]): Range for selecting the font size as a fraction of
bounding box height.
font_color (list[str] | str): List of possible font colors or a single font color.
font_color (tuple[float, ...]): Font color as RGB values (e.g., (0, 0, 0) for black).
clear_bg (bool): Whether to clear the background before rendering text.
metadata_key (str): Key to access metadata in the parameters.
p (float): Probability of applying the transform.
Expand All @@ -52,11 +52,11 @@ class TextImage(ImageOnlyTransform):
>>> transform = A.Compose([
A.TextImage(
font_path=Path("/path/to/font.ttf"),
stopwords=["the", "is", "in"],
stopwords=("the", "is", "in"),
augmentations=("insertion", "deletion"),
fraction_range=(0.5, 1.0),
font_size_fraction_range=(0.5, 0.9),
font_color=["red", "green", "blue"],
font_color=(255, 0, 0), # red in RGB
metadata_key="text_metadata",
p=0.5
)
Expand All @@ -69,7 +69,7 @@ class TextImage(ImageOnlyTransform):
class InitSchema(BaseTransformInitSchema):
font_path: str | Path
stopwords: tuple[str, ...]
augmentations: tuple[str | None, ...] | list[str | None]
augmentations: tuple[str | None, ...]
fraction_range: Annotated[
tuple[float, float],
AfterValidator(nondecreasing),
Expand All @@ -80,18 +80,18 @@ class InitSchema(BaseTransformInitSchema):
AfterValidator(nondecreasing),
AfterValidator(check_range_bounds(0, 1)),
]
font_color: list[tuple[float, ...] | float | str] | tuple[float, ...] | float | str
font_color: tuple[float, ...]
clear_bg: bool
metadata_key: str

def __init__(
self,
font_path: str | Path,
stopwords: tuple[str, ...] = ("the", "is", "in", "at", "of"),
augmentations: tuple[Literal["insertion", "swap", "deletion"] | None] = (None,),
augmentations: tuple[Literal["insertion", "swap", "deletion"] | None, ...] = (None,),
fraction_range: tuple[float, float] = (1.0, 1.0),
font_size_fraction_range: tuple[float, float] = (0.8, 0.9),
font_color: list[tuple[float, ...] | float | str] | tuple[float, ...] | float | str = "black",
font_color: tuple[float, ...] = (0, 0, 0), # black in RGB
clear_bg: bool = False,
metadata_key: str = "textimage_metadata",
p: float = 0.5,
Expand Down Expand Up @@ -174,7 +174,7 @@ def preprocess_metadata(

augmented_text = text if augmentation is None else self.random_aug(text, 0.5, choice=augmentation)

font_color = self.py_random.choice(self.font_color) if isinstance(self.font_color, list) else self.font_color
font_color = self.font_color

return {
"bbox_coords": (x_min, y_min, x_max, y_max),
Expand Down
6 changes: 3 additions & 3 deletions tests/aug_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,8 +389,8 @@
dict(
font_path="./tests/files/LiberationSerif-Bold.ttf",
font_size_fraction_range=(0.8, 0.9),
font_color="red",
stopwords=[
font_color=(255, 0, 0), # red in RGB
stopwords=(
"a",
"the",
"is",
Expand All @@ -404,7 +404,7 @@
"for",
"at",
"by",
],
),
),
],
[A.GridElasticDeform, {"num_grid_xy": (10, 10), "magnitude": 10}],
Expand Down
14 changes: 7 additions & 7 deletions tests/test_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def test_convert_image_to_pil(image_shape):
"bbox_coords": (10, 10, 100, 50),
"text": "Test",
"font": font,
"font_color": 127, # Red color
"font_color": (127,), # Grayscale color
},
],
), # Grayscale image
Expand All @@ -139,7 +139,7 @@ def test_convert_image_to_pil(image_shape):
"bbox_coords": (10, 10, 100, 50),
"text": "Test",
"font": font,
"font_color": 127, # Red color
"font_color": (127,), # Single channel color
},
],
), # Single channel image
Expand All @@ -150,24 +150,24 @@ def test_convert_image_to_pil(image_shape):
"bbox_coords": (10, 10, 100, 50),
"text": "Test",
"font": font,
"font_color": (127, 127, 127), # Red color
"font_color": (127, 127, 127), # RGB color
},
{
"bbox_coords": (20, 20, 110, 60),
"text": "Test",
"font": font,
"font_color": "red", # Red color
"font_color": (255, 0, 0), # Red color
},
],
), # RGB image with tuple and string font color
), # RGB image with tuple colors
(
(100, 100, 5),
[
{
"bbox_coords": (20, 20, 110, 60),
"text": "Test",
"font": font,
"font_color": (127, 127, 127, 127, 127),
"font_color": (127, 127, 127, 127, 127), # 5-channel color
},
],
),
Expand All @@ -178,7 +178,7 @@ def test_convert_image_to_pil(image_shape):
"bbox_coords": (20, 20, 110, 60),
"text": "Test",
"font": font,
"font_color": "red",
"font_color": (255, 0, 0), # RGB color (will be padded for 5 channels)
},
],
),
Expand Down