diff --git a/src/torchio/transforms/__init__.py b/src/torchio/transforms/__init__.py index 5ecaf147..518d25b8 100644 --- a/src/torchio/transforms/__init__.py +++ b/src/torchio/transforms/__init__.py @@ -8,6 +8,7 @@ from .augmentation.intensity import Motion from .augmentation.intensity import Noise from .augmentation.intensity import RandomBiasField +from .augmentation.intensity import RandomBiasFieldDenoise from .augmentation.intensity import RandomBlur from .augmentation.intensity import RandomGamma from .augmentation.intensity import RandomGhosting @@ -120,4 +121,5 @@ 'RemoveLabels', 'SequentialLabels', 'KeepLargestComponent', + 'RandomBiasFieldDenoise', ] diff --git a/src/torchio/transforms/augmentation/intensity/__init__.py b/src/torchio/transforms/augmentation/intensity/__init__.py index 343a8359..8fb65e59 100644 --- a/src/torchio/transforms/augmentation/intensity/__init__.py +++ b/src/torchio/transforms/augmentation/intensity/__init__.py @@ -1,5 +1,6 @@ from .random_bias_field import BiasField from .random_bias_field import RandomBiasField +from .random_biasfield_denoise import RandomBiasFieldDenoise from .random_blur import Blur from .random_blur import RandomBlur from .random_gamma import Gamma @@ -36,4 +37,5 @@ 'BiasField', 'RandomLabelsToImage', 'LabelsToImage', + 'RandomBiasFieldDenoise', ] diff --git a/src/torchio/transforms/augmentation/intensity/random_biasfield_denoise.py b/src/torchio/transforms/augmentation/intensity/random_biasfield_denoise.py new file mode 100644 index 00000000..d87045b1 --- /dev/null +++ b/src/torchio/transforms/augmentation/intensity/random_biasfield_denoise.py @@ -0,0 +1,34 @@ +from torchio.data.subject import Subject + +from ...transform import Transform + + +class RandomBiasFieldDenoise(Transform): + """ + Simple placeholder transform that simulates denoising after bias field + correction by blending voxel intensities toward the mean value. + + Parameters: + noise_reduction_factor (float): Strength of denoising (0-1). + """ + + def __init__(self, noise_reduction_factor: float = 0.1, **kwargs): + super().__init__(**kwargs) + self.noise_reduction_factor = noise_reduction_factor + + def apply_transform(self, subject: Subject) -> Subject: + for _, image in subject.get_images_dict(intensity_only=True).items(): + tensor = image.data.float() + + # Basic denoising by shifting toward mean intensity + mean_val = tensor.mean() + tensor = (tensor * (1 - self.noise_reduction_factor)) + ( + mean_val * self.noise_reduction_factor + ) + + image.set_data(tensor) + + return subject + + def __repr__(self): + return f'{self.__class__.__name__}(noise_reduction_factor={self.noise_reduction_factor})' diff --git a/visualize_denoise.py b/visualize_denoise.py new file mode 100644 index 00000000..8187b4cd --- /dev/null +++ b/visualize_denoise.py @@ -0,0 +1,34 @@ +import matplotlib.pyplot as plt + +import torchio as tio + +subject = tio.datasets.Colin27() + +# Get full 3D volume +image_3d = subject.t1.data.squeeze().numpy() + +# Pick center slice (axial) +slice_idx = image_3d.shape[2] // 2 +image = image_3d[:, :, slice_idx] + +# ---- Apply custom transform ---- +transform = tio.RandomBiasFieldDenoise(noise_reduction_factor=0.3) +denoised_subject = transform(subject) + +# Extract denoised slice +denoised_3d = denoised_subject.t1.data.squeeze().numpy() +denoised_image = denoised_3d[:, :, slice_idx] + +# ---- Plot ---- +fig, axes = plt.subplots(1, 2, figsize=(10, 5)) + +axes[0].imshow(image, cmap='gray') +axes[0].set_title('Original') +axes[0].axis('off') + +axes[1].imshow(denoised_image, cmap='gray') +axes[1].set_title('After Denoise') +axes[1].axis('off') + +plt.tight_layout() +plt.show()