diff --git a/pytorch_toolbelt/__init__.py b/pytorch_toolbelt/__init__.py index 89a30bcf9..a0f9eacc8 100644 --- a/pytorch_toolbelt/__init__.py +++ b/pytorch_toolbelt/__init__.py @@ -1,3 +1,3 @@ from __future__ import absolute_import -__version__ = "0.8.0" +__version__ = "0.8.1" diff --git a/pytorch_toolbelt/inference/ensembling.py b/pytorch_toolbelt/inference/ensembling.py index 30d9b6287..c80dd0d1c 100644 --- a/pytorch_toolbelt/inference/ensembling.py +++ b/pytorch_toolbelt/inference/ensembling.py @@ -2,9 +2,17 @@ import torch from torch import nn, Tensor -from typing import List, Union, Iterable, Optional, Dict, Tuple - -__all__ = ["ApplySoftmaxTo", "ApplySigmoidTo", "Ensembler", "PickModelOutput", "SelectByIndex", "average_checkpoints"] +from typing import List, Union, Iterable, Optional, Dict, Tuple, Mapping + +__all__ = [ + "ApplySoftmaxTo", + "ApplySigmoidTo", + "Ensembler", + "PickModelOutput", + "SelectByIndex", + "average_checkpoints", + "average_state_dicts", +] from pytorch_toolbelt.inference.tta import _deaugment_averaging @@ -163,53 +171,84 @@ def forward(self, outputs: Dict[str, Tensor]) -> Tensor: return outputs[self.target_key] -def average_checkpoints(inputs: List[str]) -> collections.OrderedDict: - """Loads checkpoints from inputs and returns a model with averaged weights. Original implementation taken from: - https://github.com/pytorch/fairseq/blob/a48f235636557b8d3bc4922a6fa90f3a0fa57955/scripts/average_checkpoints.py#L16 +def average_state_dicts(state_dicts: List[Mapping[str, Tensor]]) -> Mapping[str, Tensor]: + """ + Averages multiple 'state_dict' + + """ + + keys = state_dicts[0].keys() + final_state_dict = collections.OrderedDict() + + for key in keys: + # Collect the values (tensors) for this key from all checkpoints + values = [sd[key] for sd in state_dicts] + + # Check the dtype of the first value (assuming all dtypes match) + first_val = values[0] + + if not all(v.shape == first_val.shape for v in values): + raise ValueError(f"Tensor shapes for key '{key}' are not consistent across checkpoints.") + + if first_val.dtype == torch.bool: + # For bool, ensure all are identical + for val in values[1:]: + if not torch.equal(val, first_val): + raise ValueError(f"Boolean values for key '{key}' differ between checkpoints.") + final_state_dict[key] = first_val # Use the first if all identical + + elif torch.is_floating_point(first_val): + # Average float values + stacked = torch.stack(values, dim=0) + target_dtype = stacked.dtype + accum_dtype = torch.promote_types(target_dtype, torch.float32) # Upcast to float32 if needed + averaged = stacked.to(accum_dtype).mean(dim=0).to(target_dtype) + final_state_dict[key] = averaged + + elif first_val.dtype in { + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.uint8, + torch.uint16, + torch.uint32, + torch.uint64, + }: + # Average integer values (using integer division) + stacked = torch.stack(values, dim=0) + summed = stacked.sum(dim=0, dtype=torch.int64) + averaged = summed // len(values) + final_state_dict[key] = averaged.to(first_val.dtype) + + else: + # If you have other special dtypes to handle, add logic here + # or simply copy the first value if that is your intended behavior. + raise TypeError(f"Unsupported dtype '{first_val.dtype}' encountered for key '{key}'.") + + return final_state_dict + + +def average_checkpoints(inputs: List[str], key=None, map_location="cpu", weights_only=True) -> collections.OrderedDict: + """Loads checkpoints from inputs and returns a model with averaged weights. 
+    Args:
+        inputs (List[str]): An iterable of string paths of checkpoints to load from.
+        key (str): An optional key to select a sub-dictionary from the checkpoint.
+        map_location (str): A string describing how to remap storage locations when loading the model.
+        weights_only (bool): If True, will only load the weights of the model.
+
     Returns:
-      A dict of string keys mapping to various values. The 'model' key
-      from the returned dict should correspond to an OrderedDict mapping
-      string parameter names to torch Tensors.
+      An OrderedDict mapping string parameter names to averaged torch Tensors.
+      If `key` is given, the averaged state dict is returned wrapped under that key.
     """
-    params_dict = collections.OrderedDict()
-    params_keys = None
-    new_state = None
-    num_models = len(inputs)
-    for fpath in inputs:
-        with open(fpath, "rb") as f:
-            state = torch.load(
-                f,
-                map_location="cpu",
-            )
-        # Copies over the settings from the first checkpoint
-        if new_state is None:
-            new_state = state
-        model_params = state["model_state_dict"]
-        model_params_keys = list(model_params.keys())
-        if params_keys is None:
-            params_keys = model_params_keys
-        elif params_keys != model_params_keys:
-            raise KeyError(
-                "For checkpoint {}, expected list of params: {}, "
-                "but found: {}".format(f, params_keys, model_params_keys)
-            )
-        for k in params_keys:
-            p = model_params[k]
-            if isinstance(p, torch.HalfTensor):
-                p = p.float()
-            if k not in params_dict:
-                params_dict[k] = p.clone()
-                # NOTE: clone() is needed in case of p is a shared parameter
-            else:
-                params_dict[k] += p
-    averaged_params = collections.OrderedDict()
-    for k, v in params_dict.items():
-        averaged_params[k] = v
-        if averaged_params[k].is_floating_point():
-            averaged_params[k].div_(num_models)
-        else:
-            averaged_params[k] //= num_models
-    new_state["model_state_dict"] = averaged_params
-    return new_state
+    state_dicts = [torch.load(path, map_location=map_location, weights_only=weights_only) for path in inputs]
+    if key is not None:
+        state_dicts = [sd[key] for sd in state_dicts]
+
+    avg_state_dict = average_state_dicts(state_dicts)
+    if key is not None:
+        avg_state_dict = {key: avg_state_dict}
+
+    return avg_state_dict
diff --git a/pytorch_toolbelt/inference/functional.py b/pytorch_toolbelt/inference/functional.py
index 6a6324664..e299fdd43 100644
--- a/pytorch_toolbelt/inference/functional.py
+++ b/pytorch_toolbelt/inference/functional.py
@@ -247,7 +247,7 @@ def unpad_xyxy_bboxes(bboxes_tensor: torch.Tensor, pad, dim=-1):
     return bboxes_tensor - pad


-def geometric_mean(x: Tensor, dim: int) -> Tensor:
+def geometric_mean(x: Tensor, dim: int, keepdim=False) -> Tensor:
     """
     Compute geometric mean along given dimension.
     This implementation assume values are in range (0...1) (Probabilities)
@@ -258,10 +258,10 @@ def geometric_mean(x: Tensor, dim: int) -> Tensor:
     Returns:
         Tensor
     """
-    return x.log().mean(dim=dim).exp()
+    return x.log().mean(dim=dim, keepdim=keepdim).exp()


-def harmonic_mean(x: Tensor, dim: int, eps: float = 1e-6) -> Tensor:
+def harmonic_mean(x: Tensor, dim: int, eps: float = 1e-6, keepdim=False) -> Tensor:
     """
     Compute harmonic mean along given dimension.
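Usage sketch (not part of the patch) for the new `keepdim` argument added to `geometric_mean` and `harmonic_mean` above; the tensor shapes are illustrative only:

import torch
from pytorch_toolbelt.inference.functional import geometric_mean, harmonic_mean

# Four TTA predictions stacked along dim=0, values kept in (0, 1) as the docstrings require
probs = torch.rand(4, 3, 8, 8).clamp(1e-6, 1.0)

print(geometric_mean(probs, dim=0).shape)                # torch.Size([3, 8, 8])
print(geometric_mean(probs, dim=0, keepdim=True).shape)  # torch.Size([1, 3, 8, 8])
print(harmonic_mean(probs, dim=0, keepdim=True).shape)   # torch.Size([1, 3, 8, 8])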
@@ -273,7 +273,7 @@ def harmonic_mean(x: Tensor, dim: int, eps: float = 1e-6) -> Tensor: Tensor """ x = torch.reciprocal(x.clamp_min(eps)) - x = torch.mean(x, dim=dim) + x = torch.mean(x, dim=dim, keepdim=keepdim) x = torch.reciprocal(x.clamp_min(eps)) return x diff --git a/pytorch_toolbelt/losses/functional.py b/pytorch_toolbelt/losses/functional.py index bec6a786f..054735bc2 100644 --- a/pytorch_toolbelt/losses/functional.py +++ b/pytorch_toolbelt/losses/functional.py @@ -16,7 +16,6 @@ ] -@torch.cuda.amp.autocast(False) def focal_loss_with_logits( output: torch.Tensor, target: torch.Tensor, @@ -58,51 +57,52 @@ def focal_loss_with_logits( output = output.float() target = target.float() - if activation == "sigmoid": - p = torch.sigmoid(output) - else: - p = torch.softmax(output, dim=softmax_dim) - - ce_loss = F.binary_cross_entropy_with_logits(output, target, reduction="none") - pt = p * target + (1 - p) * (1 - target) - - # compute the loss - if reduced_threshold is None: - focal_term = (1.0 - pt).pow(gamma) - else: - focal_term = ((1.0 - pt) / (1 - reduced_threshold)).pow( - gamma - ) # the focal term continuity breaks when reduced_threshold not equal to 0.5. At pt equal to reduced_threshold, the value of piecewise function of focal term should be 1 from both sides . - focal_term = torch.masked_fill(focal_term, pt < reduced_threshold, 1) - - loss = focal_term * ce_loss - - if alpha is not None: - loss *= alpha * target + (1 - alpha) * (1 - target) - - if class_weights is not None: - # class_weights is of shape [C] - # Loss is of shape [B,C ...] - # Reshape class_weights to [1, C, ...] - class_weights = class_weights.view(1, -1, *(1 for _ in range(loss.dim() - 2))) - loss *= class_weights + with torch.amp.autocast(device_type=output.device.type, enabled=False): + if activation == "sigmoid": + p = torch.sigmoid(output) + else: + p = torch.softmax(output, dim=softmax_dim) + + ce_loss = F.binary_cross_entropy_with_logits(output, target, reduction="none") + pt = p * target + (1 - p) * (1 - target) + + # compute the loss + if reduced_threshold is None: + focal_term = (1.0 - pt).pow(gamma) + else: + focal_term = ((1.0 - pt) / (1 - reduced_threshold)).pow( + gamma + ) # the focal term continuity breaks when reduced_threshold not equal to 0.5. At pt equal to reduced_threshold, the value of piecewise function of focal term should be 1 from both sides . + focal_term = torch.masked_fill(focal_term, pt < reduced_threshold, 1) + + loss = focal_term * ce_loss + + if alpha is not None: + loss *= alpha * target + (1 - alpha) * (1 - target) + + if class_weights is not None: + # class_weights is of shape [C] + # Loss is of shape [B,C ...] + # Reshape class_weights to [1, C, ...] 
+ class_weights = class_weights.view(1, -1, *(1 for _ in range(loss.dim() - 2))) + loss *= class_weights + + if ignore_index is not None: + ignore_mask = target.eq(ignore_index) + loss = torch.masked_fill(loss, ignore_mask, 0) + if normalized: + focal_term = torch.masked_fill(focal_term, ignore_mask, 0) - if ignore_index is not None: - ignore_mask = target.eq(ignore_index) - loss = torch.masked_fill(loss, ignore_mask, 0) if normalized: - focal_term = torch.masked_fill(focal_term, ignore_mask, 0) - - if normalized: - norm_factor = focal_term.sum(dtype=torch.float32).clamp_min(eps) - loss /= norm_factor - - if reduction == "mean": - loss = loss.mean() - if reduction == "sum": - loss = loss.sum() - if reduction == "batchwise_mean": - loss = loss.sum(dim=0) + norm_factor = focal_term.sum(dtype=torch.float32).clamp_min(eps) + loss /= norm_factor + + if reduction == "mean": + loss = loss.mean() + if reduction == "sum": + loss = loss.sum() + if reduction == "batchwise_mean": + loss = loss.sum(dim=0) return loss diff --git a/pytorch_toolbelt/losses/quality_focal_loss.py b/pytorch_toolbelt/losses/quality_focal_loss.py index 95fa8763b..1a91c8bcf 100644 --- a/pytorch_toolbelt/losses/quality_focal_loss.py +++ b/pytorch_toolbelt/losses/quality_focal_loss.py @@ -20,7 +20,6 @@ def __init__(self, beta: float = 2, reduction="mean"): self.beta = beta self.reduction = reduction - @torch.cuda.amp.autocast(False) def forward(self, predictions: Tensor, targets: Tensor) -> Tensor: """ Compute quality focal loss @@ -32,15 +31,16 @@ def forward(self, predictions: Tensor, targets: Tensor) -> Tensor: predictions = predictions.float() targets = targets.float() - bce = torch.nn.functional.binary_cross_entropy_with_logits(predictions, targets, reduction="none") - focal_term = torch.nn.functional.l1_loss(predictions.sigmoid(), targets, reduction="none").pow_(self.beta) - loss = focal_term * bce - - if self.reduction == "mean": - return loss.mean() - if self.reduction == "sum": - return loss.sum() - if self.reduction == "normalized": - return loss.sum() / focal_term.sum() + with torch.amp.autocast(device_type=predictions.device.type, enabled=False): + bce = torch.nn.functional.binary_cross_entropy_with_logits(predictions, targets, reduction="none") + focal_term = torch.nn.functional.l1_loss(predictions.sigmoid(), targets, reduction="none").pow_(self.beta) + loss = focal_term * bce + + if self.reduction == "mean": + return loss.mean() + if self.reduction == "sum": + return loss.sum() + if self.reduction == "normalized": + return loss.sum() / focal_term.sum() return loss diff --git a/pytorch_toolbelt/modules/interfaces.py b/pytorch_toolbelt/modules/interfaces.py index b5799e780..4cd1c0f3d 100644 --- a/pytorch_toolbelt/modules/interfaces.py +++ b/pytorch_toolbelt/modules/interfaces.py @@ -4,6 +4,7 @@ import numpy as np import torch.jit +from torch import nn, Tensor __all__ = [ "FeatureMapsSpecification", @@ -13,10 +14,6 @@ "AbstractEncoder", ] -from torch import nn, Tensor - -from pytorch_toolbelt.utils import pytorch_toolbelt_deprecated - @dataclasses.dataclass class FeatureMapsSpecification: @@ -61,8 +58,7 @@ class HasInputFeaturesSpecification(Protocol): """ @torch.jit.unused - def get_input_spec(self) -> FeatureMapsSpecification: - ... + def get_input_spec(self) -> FeatureMapsSpecification: ... class HasOutputFeaturesSpecification(Protocol): @@ -71,8 +67,7 @@ class HasOutputFeaturesSpecification(Protocol): """ @torch.jit.unused - def get_output_spec(self) -> FeatureMapsSpecification: - ... 
+ def get_output_spec(self) -> FeatureMapsSpecification: ... class AbstractEncoder(nn.Module, HasOutputFeaturesSpecification): @@ -108,8 +103,7 @@ def __init__(self, input_spec: FeatureMapsSpecification): @abstractmethod def forward( self, feature_maps: List[Tensor], output_size: Union[Tuple[int, int], torch.Size, None] = None - ) -> Union[Tensor, Tuple[Tensor, ...], List[Tensor], Mapping[str, Tensor]]: - ... + ) -> Union[Tensor, Tuple[Tensor, ...], List[Tensor], Mapping[str, Tensor]]: ... @torch.jit.unused def apply_to_final_layer(self, func: Callable[[nn.Module], None]): diff --git a/pytorch_toolbelt/utils/distributed.py b/pytorch_toolbelt/utils/distributed.py index 3c68d7428..2710e6b41 100644 --- a/pytorch_toolbelt/utils/distributed.py +++ b/pytorch_toolbelt/utils/distributed.py @@ -9,7 +9,7 @@ import numpy as np import torch from torch import Tensor - +from contextlib import contextmanager import torch.distributed as dist from pytorch_toolbelt.utils.bucket_assignment import ( @@ -32,6 +32,7 @@ "reduce_dict_sum", "split_across_nodes", "master_node_only", + "master_node_first", ] logger = logging.getLogger("pytorch_toolbelt.utils.distributed") @@ -61,7 +62,7 @@ def __enter__(self): if self.dist_is_available and self.world_size > 1: if not self.dist_is_initialized: logger.info(f"Setting CUDA device {self.device} for rank {self.local_rank}/{self.world_size}") - torch.distributed.init_process_group(backend="nccl", world_size=self.world_size, rank=self.local_rank) + torch.distributed.init_process_group(backend="nccl", world_size=self.world_size, rank=self.local_rank, device_id=self.device) return self def __exit__(self, exc_type, exc_val, exc_tb): @@ -340,3 +341,28 @@ def wrapper(*args, **kwargs): return None return wrapper + + +@contextmanager +def master_node_first(local_rank: int | None = None): + """ + Execute some code on master node first, then wait for all other nodes to finish. + + Usage: + with master_node_first(): + ... + + """ + if local_rank is None: + local_rank = get_rank() + + if local_rank > 0: + dist.barrier() + yield + if local_rank == 0: + if not dist.is_available(): + return + if not dist.is_initialized(): + return + else: + dist.barrier() diff --git a/pytorch_toolbelt/utils/torch_utils.py b/pytorch_toolbelt/utils/torch_utils.py index 06e9ffb48..433dcf5b9 100644 --- a/pytorch_toolbelt/utils/torch_utils.py +++ b/pytorch_toolbelt/utils/torch_utils.py @@ -17,6 +17,8 @@ from .support import pytorch_toolbelt_deprecated +logger = logging.getLogger("pytorch_toolbelt.utils") + __all__ = [ "argmax_over_dim_0", "argmax_over_dim_1", @@ -47,6 +49,7 @@ "get_collate_for_dataset", "get_non_wrapped_model", "container_to_tensor", + "convert_2d_to_3d", ] @@ -176,7 +179,7 @@ def to_tensor(x, dtype=None) -> torch.Tensor: x = x.type(dtype) return x if isinstance(x, (list, tuple)): - x = np.ndarray(x) + x = np.array(x) x = torch.from_numpy(x) if dtype is not None: x = x.type(dtype) @@ -284,10 +287,34 @@ def maybe_cuda(x: Union[torch.Tensor, nn.Module]) -> Union[torch.Tensor, nn.Modu return x -logger = logging.getLogger("pytorch_toolbelt.utils") - +@dataclasses.dataclass +class TransferWeightsOuptut: + """ + Output of transfer_weights function. Holds information about how many layers were loaded, skipped, etc. + Can be used to get detailed information about how many layers were loaded from checkpoint to model. 
+ """ -def transfer_weights(model: nn.Module, model_state_dict: collections.OrderedDict, incompatible_shape_action="skip"): + loaded_layers: List[str] + skipped_layers: List[str] + missing_layers_in_model: List[str] + missing_layers_in_checkpoint: List[str] + + def __repr__(self): + total_layers_in_checkpoint = ( + len(self.loaded_layers) + len(self.missing_layers_in_model) + len(self.skipped_layers) + ) + total_layers_in_model = ( + len(self.loaded_layers) + len(self.missing_layers_in_checkpoint) + len(self.skipped_layers) + ) + loaded_layers_percentage = 100.0 * len(self.loaded_layers) / total_layers_in_checkpoint + skipped_layers_percentage = 100.0 * len(self.skipped_layers) / total_layers_in_checkpoint + model_initialized_percentage = 100.0 * len(self.loaded_layers) / total_layers_in_model + return f"TransferWeightsOuptut({model_initialized_percentage=:.2f}, {loaded_layers_percentage=:.2f}, {skipped_layers_percentage=:.2f})" + + +def transfer_weights( + model: nn.Module, model_state_dict: collections.OrderedDict, incompatible_shape_action="skip" +) -> TransferWeightsOuptut: """ Copy weights from state dict to model, skipping layers that are incompatible. This method is helpful if you are doing some model surgery and want to load @@ -295,14 +322,17 @@ def transfer_weights(model: nn.Module, model_state_dict: collections.OrderedDict :param model: Model to load weights into :param model_state_dict: Model state dict to load weights from :param incompatible_shape_action: What to do if shape of weight tensor is incompatible. - Possible values are: - - "skip" - Skip loading this tensor - - "match_mean_std" - Initialize tensor with random values with same mean and std as source tensor - :return: None + Possible values are: + - "skip" - Skip loading this tensor + - "match_mean_std" - Initialize tensor with random values with same mean and std as source tensor + :return: Instance of TransferWeightsOuptut """ existing_model_state_dict = model.state_dict() - loaded_layers = 0 + loaded_layer_names = [] + skipped_layers_names = [] + layers_not_in_model = list(set(existing_model_state_dict.keys()) - set(model_state_dict.keys())) + layers_not_in_checkpoint = list(set(model_state_dict.keys()) - set(existing_model_state_dict.keys())) for name, value in model_state_dict.items(): if name not in existing_model_state_dict: @@ -314,6 +344,7 @@ def transfer_weights(model: nn.Module, model_state_dict: collections.OrderedDict existing_value = existing_model_state_dict[name] if value.shape != existing_value.shape: if incompatible_shape_action == "skip": + skipped_layers_names.append(name) logger.debug( f"transfer_weights skipped loading weights for key {name}, because of checkpoint has shape {value.shape} and model has shape {existing_model_state_dict[name].shape}" ) @@ -330,14 +361,15 @@ def transfer_weights(model: nn.Module, model_state_dict: collections.OrderedDict try: model.load_state_dict(collections.OrderedDict([(name, value)]), strict=False) - loaded_layers += 1 + loaded_layer_names.append(name) except Exception as e: logger.debug(f"transfer_weights skipped loading weights for key {name}, because of error: {e}") - percentage_of_layers_from_checkpoint = loaded_layers / len(model_state_dict) * 100 - percentage_of_layers_in_model = loaded_layers / len(existing_model_state_dict) * 100 - logger.info( - f"Transferred {percentage_of_layers_from_checkpoint:.2f}% of layers from checkpoint to model, filling {percentage_of_layers_in_model:.2f}% of model layers" + return TransferWeightsOuptut( + 
loaded_layers=loaded_layer_names, + skipped_layers=skipped_layers_names, + missing_layers_in_model=layers_not_in_model, + missing_layers_in_checkpoint=layers_not_in_checkpoint, ) @@ -485,3 +517,91 @@ def get_non_wrapped_model(model: nn.Module) -> nn.Module: model = model.module return model + + +def convert_2d_to_3d(model: nn.Module) -> nn.Module: + """ + Recursively convert all 2d layers in `model` to their 3d versions. + This method converts 2D CNN model to 3D version. Important note - models/layers with non-trivial forward() method + probably not going to work (LayerNorm2d or GlobalResponseNormalization for instance) + Replicates the existing Conv2d weights along the 3rd dimension (depth=1 by default) and scales them accordingly. + + :param model: Model to convert + """ + for name, module in model.named_children(): + + # If we find a Conv2d, replace it with a Conv3d. + if isinstance(module, nn.Conv2d): + # -------------------------------------------- + # 1) Check that the 2D kernel is square + # -------------------------------------------- + if module.kernel_size[0] != module.kernel_size[1]: + raise ValueError( + f"Non-square kernel detected: {module.kernel_size}. " + "This example only handles square kernels (k, k)." + ) + k = module.kernel_size[0] + + # -------------------------------------------- + # 2) Build a new Conv3d with kernel_size = (k, k, k) + # using the same hyperparameters as best we can + # -------------------------------------------- + new_conv = nn.Conv3d( + in_channels=module.in_channels, + out_channels=module.out_channels, + kernel_size=(k, k, k), + # For stride, padding, and dilation, we replicate + # the 2D values in each dimension: + stride=(module.stride[0], module.stride[0], module.stride[1]), + padding=(module.padding[0], module.padding[0], module.padding[1]), + dilation=(module.dilation[0], module.dilation[0], module.dilation[1]), + groups=module.groups, + bias=(module.bias is not None), + ) + + # -------------------------------------------- + # 3) Copy and replicate the 2D weights -> 3D + # old_weight shape: (out_c, in_c, k, k) + # new_weight shape: (out_c, in_c, k, k, k) + # -------------------------------------------- + with torch.no_grad(): + old_weight = module.weight # shape: (out_c, in_c, k, k) + # Expand along a new depth dimension + old_weight_3d = old_weight.unsqueeze(2) # (out_c, in_c, 1, k, k) + old_weight_3d = old_weight_3d.repeat(1, 1, k, 1, 1).div(k) # (out_c, in_c, k, k, k) + new_conv.weight.copy_(old_weight_3d) + + if module.bias is not None: + new_conv.bias.copy_(module.bias) + + # Replace the old Conv2d with our new Conv3d + setattr(model, name, new_conv) + + # If we find a BatchNorm2d, replace it with a BatchNorm3d. 
+ elif isinstance(module, nn.BatchNorm2d): + new_bn = nn.BatchNorm3d( + num_features=module.num_features, + eps=module.eps, + momentum=module.momentum, + affine=module.affine, + track_running_stats=module.track_running_stats, + ) + + # Copy running statistics and affine parameters + with torch.no_grad(): + if module.affine: + new_bn.weight.copy_(module.weight) + new_bn.bias.copy_(module.bias) + new_bn.running_mean.copy_(module.running_mean) + new_bn.running_var.copy_(module.running_var) + + # Replace the BatchNorm2d with BatchNorm3d + setattr(model, name, new_bn) + elif isinstance(module, nn.Dropout2d): + # Replace with Dropout3d + setattr(model, name, nn.Dropout3d(p=module.p, inplace=module.inplace)) + else: + # Recursively convert children + convert_2d_to_3d(module) + + return model diff --git a/pytorch_toolbelt/utils/visualization.py b/pytorch_toolbelt/utils/visualization.py index 3c7e9bdfa..16f258d5e 100644 --- a/pytorch_toolbelt/utils/visualization.py +++ b/pytorch_toolbelt/utils/visualization.py @@ -295,6 +295,88 @@ def vstack_autopad(images: Iterable[np.ndarray], pad_value: int = 0, spacing: in return np.vstack(padded_images) +def wrap_text_to_width(text, font_face, font_scale, thickness, max_width): + """ + Splits a text string into multiple lines so that each line's width (using the provided text parameters) + does not exceed max_width. + + Priority is given to whitespace when wrapping; if there are no spaces, + it wraps at the character level. + After lines are formed, leading and trailing whitespace is stripped. + + :param text: The input text to wrap. + :param font_face: OpenCV font face (e.g., cv2.FONT_HERSHEY_SIMPLEX). + :param font_scale: Font scale factor that is multiplied by the base font size. + :param thickness: Thickness of the strokes used to draw text. + :param max_width: Maximum allowed width (in pixels) for one line of text. + :return: A list of text lines (strings). + """ + + # Early exit for empty text + if not text.strip(): + return [] + + # If the text has no whitespace, we switch to character-level wrapping + has_whitespace = " " in text + + # Depending on presence of whitespace, choose how to split the text initially + if has_whitespace: + tokens = text.split(" ") + else: + # No whitespace - treat every character as a separate "token" + tokens = list(text) + + lines = [] + current_line = "" + + def fits_in_width(candidate_text): + """Check if candidate_text fits within max_width using cv2.getTextSize.""" + size, _ = cv2.getTextSize(candidate_text, font_face, font_scale, thickness) + return size[0] <= max_width + + for i, token in enumerate(tokens): + # If there is whitespace, tokens are words; otherwise tokens are individual characters. + # We add a space if it's word-based wrapping (has_whitespace and not the very first word) + if has_whitespace: + tentative_line = (current_line + " " + token) if current_line else token + else: + # If we are in character-mode, do not prepend a space + tentative_line = current_line + token + + # Check if the tentative line fits + if fits_in_width(tentative_line): + current_line = tentative_line + else: + # If it doesn't fit, we need to finalize the current_line and start a new one. 
+ if current_line: + lines.append(current_line.strip()) + # In word-based mode, if a single token (word) doesn't fit on an empty line, + # we might need to break it further by characters: + if has_whitespace and not fits_in_width(token): + # Break this word by character + char_line = "" + for ch in token: + if fits_in_width(char_line + ch): + char_line += ch + else: + if char_line: + lines.append(char_line.strip()) + char_line = ch + current_line = char_line # start the next line with leftover + else: + # Start new current_line with the token + current_line = token if not has_whitespace else token + + # Append any leftover text in current_line + if current_line: + lines.append(current_line.strip()) + + # Strip each line (remove leading/trailing whitespace) just in case + lines = [line.strip() for line in lines if line.strip()] + + return lines + + def vstack_header( image: np.ndarray, title: str, @@ -302,22 +384,43 @@ def vstack_header( text_color=(242, 248, 248), text_thickness: int = 2, text_scale=1.5, + wrap_text: bool = False, + text_font_face=cv2.FONT_HERSHEY_PLAIN, ) -> np.ndarray: (rows, cols) = image.shape[:2] - title_image = np.zeros((30, cols, 3), dtype=np.uint8) - title_image[:] = bg_color - cv2.putText( - title_image, - title, - (10, 24), - fontFace=cv2.FONT_HERSHEY_PLAIN, - fontScale=text_scale, - color=text_color, - thickness=text_thickness, - lineType=cv2.LINE_AA, - ) + image_width = image.shape[1] + (width, height), baseline = cv2.getTextSize(title, text_font_face, text_scale, text_thickness) + padding_left = 10 + padding_right = 10 + row_height = int(height * 2 + 0.5) + + if wrap_text and (width + padding_left + padding_right) > image_width: + lines = wrap_text_to_width( + title, text_font_face, text_scale, text_thickness, image_width - padding_left - padding_right + ) + else: + lines = [title] + + title_images = [] + + for line in lines: + title_image = np.zeros((row_height, cols, 3), dtype=np.uint8) + title_image[:] = bg_color + cv2.putText( + title_image, + line, + (padding_left, row_height - int(height * 0.5)), + fontFace=text_font_face, + fontScale=text_scale, + color=text_color, + thickness=text_thickness, + lineType=cv2.LINE_AA, + ) + title_images.append(title_image) + + title_image = np.vstack(title_images) return vstack_autopad([title_image, image]) @@ -337,6 +440,6 @@ def grid_stack( image_rows = [] for r in range(rows): - image_rows.append(hstack_autopad(images[r * cols : (r + 1) * cols], bg_color=bg_color, spacing=spacing)) + image_rows.append(hstack_autopad(images[r * cols : (r + 1) * cols], pad_value=bg_color, spacing=spacing)) - return vstack_autopad(image_rows, bg_color=bg_color, spacing=spacing) + return vstack_autopad(image_rows, pad_value=bg_color, spacing=spacing) diff --git a/setup.py b/setup.py index 4aee12e4d..c6d88369c 100644 --- a/setup.py +++ b/setup.py @@ -117,7 +117,7 @@ def load_readme(): def get_test_requirements(): - requirements = ["pytest", "black==23.3.0", "timm==0.6.7", "matplotlib"] + requirements = ["pytest", "black~=24.8.0", "timm==0.6.7", "matplotlib"] return requirements diff --git a/tests/test_visualization.py b/tests/test_visualization.py index 0512113ab..9512289df 100644 --- a/tests/test_visualization.py +++ b/tests/test_visualization.py @@ -1,6 +1,7 @@ +import matplotlib.pyplot as plt import numpy as np -from pytorch_toolbelt.utils import plot_confusion_matrix, plot_heatmap +from pytorch_toolbelt.utils import plot_confusion_matrix, plot_heatmap, vstack_header def test_plot_confusion_matrix(): @@ -15,3 +16,12 @@ def 
test_plot_heatmap(): cm = np.random.randn(20, 30) plot_heatmap(cm, title="Test", x_label="30", y_label="20", fname="test_plot_heatmap.png", noshow=False) + + +def test_vstack_header(): + title = "A very long header text that would not fit in a single line and should be wrapped into multiple lines to fit the plot" + image = np.full((256, 256, 3), 255, dtype=np.uint8) + image2 = vstack_header(image, title, text_scale=2, wrap_text=True) + plt.figure() + plt.imshow(image2) + plt.show()
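Below are a few hedged usage sketches for the APIs touched by this patch; none of these snippets are part of the diff itself.

Checkpoint averaging via the new `average_state_dicts` / `average_checkpoints` helpers. The in-memory example is runnable as-is; the on-disk paths and the "model_state_dict" key in the commented lines are placeholders:

import torch
from torch import nn
from pytorch_toolbelt.inference.ensembling import average_state_dicts, average_checkpoints

# Average state dicts that are already in memory
models = [nn.Linear(4, 2) for _ in range(3)]
avg_sd = average_state_dicts([m.state_dict() for m in models])
merged = nn.Linear(4, 2)
merged.load_state_dict(avg_sd)

# For checkpoints on disk, the file names and the "model_state_dict" key are placeholders:
# averaged = average_checkpoints(["fold0.pth", "fold1.pth"], key="model_state_dict", map_location="cpu")
# merged.load_state_dict(averaged["model_state_dict"])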
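The focal-loss change replaces the CUDA-only `@torch.cuda.amp.autocast(False)` decorator with an internal, device-aware `torch.amp.autocast(..., enabled=False)` block. A minimal sketch of calling it from inside an autocast region (shapes and hyper-parameters are illustrative):

import torch
from pytorch_toolbelt.losses.functional import focal_loss_with_logits

logits = torch.randn(2, 1, 32, 32)
targets = torch.randint(0, 2, (2, 1, 32, 32)).float()

# The surrounding autocast (bf16 on CPU here) no longer matters:
# the loss internally disables autocast and computes in full precision.
with torch.amp.autocast(device_type="cpu", dtype=torch.bfloat16):
    loss = focal_loss_with_logits(logits, targets, gamma=2.0, alpha=0.25)

print(loss.item())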
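A hypothetical use of the new `master_node_first()` context manager; `prepare_dataset_cache` is a placeholder for whatever work only rank 0 should do before the other ranks proceed (e.g. downloading or unpacking data):

from pytorch_toolbelt.utils.distributed import master_node_first


def prepare_dataset_cache(cache_dir: str) -> None:
    # Placeholder: download/unpack data into cache_dir on the first call,
    # reuse the cached copy on subsequent calls.
    ...


# Rank 0 enters the body first; the other ranks wait on a barrier and then
# find the cache already populated.
with master_node_first():
    prepare_dataset_cache("/tmp/dataset-cache")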
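`transfer_weights` now returns a report object instead of only logging percentages. A small self-contained sketch of inspecting it; the two toy models stand in for a real model/checkpoint pair:

import torch
from torch import nn
from pytorch_toolbelt.utils.torch_utils import transfer_weights

# A "checkpoint" whose first layer matches the target model and whose second layer does not
source = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 4))
target = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))

report = transfer_weights(target, source.state_dict(), incompatible_shape_action="skip")
print(report)                 # summary with loaded/skipped percentages
print(report.loaded_layers)   # e.g. ['0.weight', '0.bias']
print(report.skipped_layers)  # e.g. ['1.weight', '1.bias']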
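A runnable sketch of `convert_2d_to_3d` on a toy network; as the docstring warns, modules with non-trivial `forward()` (e.g. LayerNorm2d) are not covered by this conversion:

import torch
from torch import nn
from pytorch_toolbelt.utils.torch_utils import convert_2d_to_3d

net2d = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(inplace=True),
    nn.Conv2d(8, 4, kernel_size=3, padding=1),
)
net3d = convert_2d_to_3d(net2d)

volume = torch.randn(2, 1, 16, 16, 16)  # (batch, channels, depth, height, width)
print(net3d(volume).shape)  # torch.Size([2, 4, 16, 16, 16])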
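The new `test_vstack_header` exercises `vstack_header(..., wrap_text=True)`; the helper `wrap_text_to_width` can also be called directly, e.g. to pre-compute how a caption will be split (the caption and pixel width below are arbitrary):

import cv2
from pytorch_toolbelt.utils.visualization import wrap_text_to_width

lines = wrap_text_to_width(
    "A very long caption that will not fit into a narrow image header",
    font_face=cv2.FONT_HERSHEY_PLAIN,
    font_scale=1.5,
    thickness=2,
    max_width=200,
)
print(lines)  # each returned line renders at most 200 px wide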