Skip to content

Commit 9c41743

Browse files
authored
Add stain augment (albumentations-team#2337)
* Added tests for SimpleNMF * Pass tests in functional * Cleanup * Tests pass * Added sklearn to tests * Refactoring * Added to README
1 parent c650a54 commit 9c41743

File tree

8 files changed

+800
-2
lines changed

8 files changed

+800
-2
lines changed

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ Pixel-level transforms will change just an input image and will leave any additi
204204
- [GaussNoise](https://explore.albumentations.ai/transform/GaussNoise)
205205
- [GaussianBlur](https://explore.albumentations.ai/transform/GaussianBlur)
206206
- [GlassBlur](https://explore.albumentations.ai/transform/GlassBlur)
207+
- [HEStain](https://explore.albumentations.ai/transform/HEStain)
207208
- [HistogramMatching](https://explore.albumentations.ai/transform/HistogramMatching)
208209
- [HueSaturationValue](https://explore.albumentations.ai/transform/HueSaturationValue)
209210
- [ISONoise](https://explore.albumentations.ai/transform/ISONoise)

albumentations/augmentations/functional.py

+311
Original file line numberDiff line numberDiff line change
@@ -3136,3 +3136,314 @@ def get_mud_params(
31363136
"mud": mud.astype(np.float32),
31373137
"non_mud": non_mud.astype(np.float32),
31383138
}
3139+
3140+
3141+
# Preset H&E stain matrices, keyed by preset name.
# Every value is a 2x3 array holding one stain vector per row. The "ruifrok"
# rows are labelled hematoxylin then eosin; the other presets presumably
# follow the same row order.
STAIN_MATRICES = {
    # Ruifrok & Johnston standard reference
    "ruifrok": np.array(
        [
            [0.644211, 0.716556, 0.266844],  # Hematoxylin
            [0.092789, 0.954111, 0.283111],  # Eosin
        ],
    ),
    # Macenko's reference
    "macenko": np.array([[0.5626, 0.7201, 0.4062], [0.2159, 0.8012, 0.5581]]),
    # Standard bright-field microscopy
    "standard": np.array([[0.65, 0.70, 0.29], [0.07, 0.99, 0.11]]),
    # Enhanced contrast
    "high_contrast": np.array([[0.55, 0.88, 0.11], [0.12, 0.86, 0.49]]),
    # Hematoxylin dominant
    "h_heavy": np.array([[0.75, 0.61, 0.32], [0.04, 0.93, 0.36]]),
    # Eosin dominant
    "e_heavy": np.array([[0.60, 0.75, 0.28], [0.17, 0.95, 0.25]]),
    # Darker staining
    "dark": np.array([[0.78, 0.55, 0.28], [0.09, 0.97, 0.21]]),
    # Lighter staining
    "light": np.array([[0.57, 0.71, 0.38], [0.15, 0.89, 0.42]]),
}
3192+
3193+
3194+
def rgb_to_optical_density(img: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Convert an image to per-pixel optical density.

    The image is flattened to rows of three channel values, divided by the
    maximum value registered for its dtype in ``MAX_VALUES_BY_DTYPE``,
    floored at ``eps`` so the logarithm stays finite, and mapped through
    ``-log``.

    Args:
        img: Image that can be flattened with ``reshape(-1, 3)``.
        eps: Lower bound applied to the scaled intensities before the logarithm.

    Returns:
        Array of shape (num_pixels, 3) with ``-log`` of the scaled intensities.
    """
    scale = MAX_VALUES_BY_DTYPE[img.dtype]
    scaled = img.reshape(-1, 3).astype(np.float32) / scale
    return -np.log(np.maximum(scaled, eps))
3199+
3200+
3201+
def normalize_vectors(vectors: np.ndarray) -> np.ndarray:
    """Scale every row of ``vectors`` to unit Euclidean length.

    Rows whose norm is exactly zero are returned unchanged (all zeros)
    instead of producing NaN/inf through a division by zero.

    Args:
        vectors: 2D array with one vector per row.

    Returns:
        Array of the same shape with each non-zero row divided by its L2 norm.
    """
    norms = np.sqrt(np.sum(vectors**2, axis=1, keepdims=True))
    # ``norms`` is a fresh array, so patching it in place cannot touch the input.
    # Dividing a zero row by 1 leaves it at zero rather than NaN.
    norms[norms == 0] = 1
    return vectors / norms
3204+
3205+
3206+
def get_normalizer(method: Literal["vahadane", "macenko"]) -> StainNormalizer:
    """Create a new stain normalizer for ``method``.

    ``"vahadane"`` yields a ``VahadaneNormalizer``; every other value yields a
    ``MacenkoNormalizer``.
    """
    if method == "vahadane":
        return VahadaneNormalizer()
    return MacenkoNormalizer()
3209+
3210+
3211+
class StainNormalizer:
    """Common interface for objects that estimate an image's stain matrix."""

    def __init__(self) -> None:
        # No estimate exists until a subclass's ``fit`` has been run.
        self.stain_matrix_target = None

    def fit(self, img: np.ndarray) -> None:
        """Estimate the stain matrix of ``img``; subclasses must override this."""
        raise NotImplementedError
3220+
3221+
3222+
class SimpleNMF:
    """Two-stain non-negative factorization using multiplicative updates.

    Optical density is approximated as ``concentrations @ colors``, where
    ``colors`` starts from the Ruifrok hematoxylin/eosin vectors.

    Args:
        n_iter: Number of update rounds performed by ``fit_transform``.
    """

    def __init__(self, n_iter: int = 100):
        self.n_iter = n_iter
        # Seed vectors: standard H&E colors from Ruifrok.
        hematoxylin = [0.644211, 0.716556, 0.266844]
        eosin = [0.092789, 0.954111, 0.283111]
        self.initial_colors = np.array([hematoxylin, eosin], dtype=np.float32)

    def fit_transform(self, optical_density: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """Factor ``optical_density`` into stain concentrations and stain colors.

        Args:
            optical_density: Array with one pixel per row and three columns.

        Returns:
            ``(stain_concentrations, stain_colors)``: a (num_pixels, 2) array and
            a (2, 3) array whose rows are re-normalized after every round.
        """
        eps = 1e-6  # keeps the update ratios finite when a denominator vanishes
        colors = self.initial_colors.copy()

        # Physically meaningful start: clip the projection of every pixel onto
        # the unit-length seed colors at zero.
        concentrations = np.maximum(optical_density @ normalize_vectors(colors).T, 0)

        for _ in range(self.n_iter):
            # Concentration step (in place, so the array keeps its dtype).
            concentrations *= (optical_density @ colors.T) / (concentrations @ (colors @ colors.T) + eps)
            concentrations = np.maximum(concentrations, 0)

            # Color step, then force non-negative unit-length rows.
            colors *= (concentrations.T @ optical_density) / ((concentrations.T @ concentrations) @ colors + eps)
            colors = normalize_vectors(np.maximum(colors, 0))

        return concentrations, colors
3264+
3265+
3266+
def order_stains_combined(stain_colors: np.ndarray) -> tuple[int, int]:
    """Decide which of two stain vectors is hematoxylin and which is eosin.

    Each vector gets a score combining its angle in the plane of the first
    two channels with its share of the third and first channels. The vector
    with the higher score is taken as hematoxylin.

    Args:
        stain_colors: Array with exactly two stain vectors as rows (the eosin
            index is computed as ``1 - hematoxylin_idx``).

    Returns:
        ``(hematoxylin_idx, eosin_idx)`` as plain Python integers.
    """
    # Normalize stain vectors
    stain_colors = normalize_vectors(stain_colors)

    # Calculate angles (Macenko)
    angles = np.mod(np.arctan2(stain_colors[:, 1], stain_colors[:, 0]), np.pi)

    # Calculate spectral ratios (Ruifrok); the per-row sum is shared by both.
    channel_sum = np.sum(stain_colors, axis=1) + 1e-6
    blue_ratio = stain_colors[:, 2] / channel_sum
    red_ratio = stain_colors[:, 0] / channel_sum

    # High angle and high blue ratio indicates Hematoxylin;
    # low angle and high red ratio indicates Eosin.
    scores = angles * blue_ratio - red_ratio

    # np.argmax returns a NumPy integer; convert so the result matches the
    # annotated ``tuple[int, int]``.
    hematoxylin_idx = int(np.argmax(scores))
    return hematoxylin_idx, 1 - hematoxylin_idx
3291+
3292+
3293+
class VahadaneNormalizer(StainNormalizer):
    """Stain normalizer that estimates the stain matrix with ``SimpleNMF``."""

    def fit(self, img: np.ndarray) -> None:
        """Estimate the stain matrix of ``img`` and store it on the instance.

        The image is converted to optical density, factorized with 100 NMF
        iterations, and the two resulting color vectors are stored in
        ``stain_matrix_target`` ordered as [hematoxylin, eosin].
        """
        factorizer = SimpleNMF(n_iter=100)
        _, colors = factorizer.fit_transform(rgb_to_optical_density(img))

        # Combined angular/spectral scoring decides which row is which stain.
        h_idx, e_idx = order_stains_combined(colors)
        self.stain_matrix_target = np.array([colors[h_idx], colors[e_idx]])
3309+
3310+
3311+
class MacenkoNormalizer(StainNormalizer):
    """Macenko stain normalizer with optimized computations.

    Args:
        angular_percentile: Default percentile used by ``fit`` to pick the two
            extreme angles in the principal plane.
    """

    def __init__(self, angular_percentile: float = 99):
        super().__init__()
        self.angular_percentile = angular_percentile

    def fit(self, img: np.ndarray, angular_percentile: float | None = None) -> None:
        """Extract the H&E stain matrix using Macenko's method.

        The result is stored in ``self.stain_matrix_target`` as a (2, 3) array
        of non-negative rows normalized to unit length, ordered as
        [hematoxylin, eosin].

        Args:
            img: Image to analyse; converted with ``rgb_to_optical_density``.
            angular_percentile: Percentile for the robust extreme-angle
                estimate. ``None`` (the default) uses the value passed to the
                constructor.

        Raises:
            ValueError: If no pixel exceeds the optical-density threshold.
        """
        # Honour the constructor setting unless the caller overrides it here.
        if angular_percentile is None:
            angular_percentile = self.angular_percentile

        # Step 1: Convert RGB to optical density (OD) space
        optical_density = rgb_to_optical_density(img)

        # Step 2: Remove background pixels
        od_threshold = 0.05
        threshold_mask = (optical_density > od_threshold).any(axis=1)
        tissue_density = optical_density[threshold_mask]

        if len(tissue_density) < 1:
            raise ValueError(f"No tissue pixels found (threshold={od_threshold})")

        # Step 3: Compute covariance matrix
        tissue_density = np.ascontiguousarray(tissue_density, dtype=np.float32)
        od_covariance = cv2.calcCovarMatrix(
            tissue_density,
            None,
            cv2.COVAR_NORMAL | cv2.COVAR_ROWS | cv2.COVAR_SCALE,
        )[0]

        # Step 4: Get principal components
        eigenvalues, eigenvectors = cv2.eigen(od_covariance)[1:]
        idx = np.argsort(eigenvalues.ravel())[-2:]
        # cv2.eigen stores eigenvectors as matrix ROWS (unlike numpy.linalg.eigh,
        # which uses columns). Select rows and transpose so the result is a
        # (3, 2) basis whose columns are the two principal directions.
        principal_eigenvectors = np.ascontiguousarray(eigenvectors[idx].T, dtype=np.float32)

        # Step 5: Project onto eigenvector plane
        plane_coordinates = tissue_density @ principal_eigenvectors

        # Step 6: Find angles of extreme points
        polar_angles = np.arctan2(
            plane_coordinates[:, 1],
            plane_coordinates[:, 0],
        )

        # Get robust angle estimates
        hematoxylin_angle = np.percentile(polar_angles, 100 - angular_percentile)
        eosin_angle = np.percentile(polar_angles, angular_percentile)

        # Step 7: Convert angles back to RGB space
        hem_cos, hem_sin = np.cos(hematoxylin_angle), np.sin(hematoxylin_angle)
        eos_cos, eos_sin = np.cos(eosin_angle), np.sin(eosin_angle)

        angle_to_vector = np.array(
            [[hem_cos, hem_sin], [eos_cos, eos_sin]],
            dtype=np.float32,
        )
        stain_vectors = cv2.gemm(
            angle_to_vector,
            principal_eigenvectors.T,
            1,
            None,
            0,
        )

        # Step 8: Ensure non-negativity by taking absolute values
        # This is valid because stain vectors represent absorption coefficients
        stain_vectors = np.abs(stain_vectors)

        # Step 9: Normalize vectors to unit length
        stain_vectors = stain_vectors / np.sqrt(np.sum(stain_vectors**2, axis=1, keepdims=True))

        # Step 10: Order vectors as [hematoxylin, eosin]
        # Hematoxylin typically has larger red component
        self.stain_matrix_target = stain_vectors if stain_vectors[0, 0] > stain_vectors[1, 0] else stain_vectors[::-1]
3383+
3384+
3385+
def get_tissue_mask(img: np.ndarray, threshold: float = 0.85) -> np.ndarray:
    """Flag tissue pixels by thresholding luminosity.

    Args:
        img: RGB image in float32 format, range [0, 1]
        threshold: Luminosity cut-off in the range 0 to 1. Pixels strictly
            darker than this value count as tissue. Default: 0.85

    Returns:
        Flattened boolean mask (one entry per pixel), True where tissue is found
    """
    red, green, blue = img[..., 0], img[..., 1], img[..., 2]

    # Grayscale conversion with RGB weights: R*0.299 + G*0.587 + B*0.114
    luminosity = red * 0.299 + green * 0.587 + blue * 0.114

    # Tissue absorbs light, so it sits below the threshold.
    return (luminosity < threshold).reshape(-1)
3403+
3404+
3405+
@clipped
@float32_io
def apply_he_stain_augmentation(
    img: np.ndarray,
    stain_matrix: np.ndarray,
    scale_factors: np.ndarray,
    shift_values: np.ndarray,
    augment_background: bool,
) -> np.ndarray:
    """Perturb the stain concentrations of an H&E image.

    The image is converted to optical density, decomposed into per-pixel
    stain concentrations with respect to ``stain_matrix``, the concentrations
    are scaled and shifted, and the image is rebuilt from the result.

    Args:
        img: Image to augment; must be flattenable to rows of three channels.
            NOTE(review): the tissue mask assumes float values in [0, 1],
            presumably guaranteed by ``float32_io`` - confirm.
        stain_matrix: Array with one stain vector per row and three columns.
        scale_factors: Multiplicative factor applied to the concentrations
            (broadcast against the (num_pixels, num_stains) array).
        shift_values: Additive offset applied after scaling (same broadcasting).
        augment_background: If True every pixel is modified; otherwise only
            pixels selected by ``get_tissue_mask``.

    Returns:
        Array with the shape of ``img`` holding ``exp(-optical_density)``.
    """
    # Step 1: Convert RGB to optical density space
    optical_density = rgb_to_optical_density(img)

    # Step 2: Calculate stain concentrations using regularized pseudo-inverse
    stain_matrix = np.ascontiguousarray(stain_matrix, dtype=np.float32)

    # Add small regularization term for numerical stability
    regularization = 1e-6
    stain_correlation = stain_matrix @ stain_matrix.T + regularization * np.eye(stain_matrix.shape[0])
    density_projection = stain_matrix @ optical_density.T

    try:
        # Solve the regularized normal equations for the concentrations
        stain_concentrations = np.linalg.solve(stain_correlation, density_projection).T
    except np.linalg.LinAlgError:
        # Fallback to least squares if the direct solve fails. lstsq needs the
        # right-hand side as (3, num_pixels) to match the (3, num_stains)
        # coefficient matrix, hence the transpose of ``optical_density``.
        stain_concentrations = np.linalg.lstsq(
            stain_matrix.T,
            optical_density.T,
            rcond=regularization,
        )[0].T

    # Step 3: Apply concentration adjustments
    if augment_background:
        # Modify all pixels
        stain_concentrations = stain_concentrations * scale_factors + shift_values
    else:
        # Only modify tissue regions (the mask is already flattened per pixel)
        tissue_mask = get_tissue_mask(img)
        stain_concentrations[tissue_mask] = stain_concentrations[tissue_mask] * scale_factors + shift_values

    # Step 4: Reconstruct RGB image
    optical_density_result = stain_concentrations @ stain_matrix
    rgb_result = np.exp(-optical_density_result)

    return rgb_result.reshape(img.shape)

0 commit comments

Comments
 (0)