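Describe the bug
When training a Patchcore model and setting the image size with `Patchcore.configure_pre_processor(image_size=(768, 1536))`, the resolution is automatically changed to a lower one than specified. This does not happen if you also pass a `center_crop_size`.
I tracked this down to the file:
https://github.com/openvinotoolkit/anomalib/blob/v2.0.0/src/anomalib/models/image/patchcore/lightning_model.py#L188
Specifically, this code under `def configure_pre_processor`:
```python
if center_crop_size is None:
    # scale center crop size proportional to image size
    height, width = image_size
    center_crop_size = (int(height * (224 / 256)), int(width * (224 / 256)))
```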
With my previous image_size of (768, 1536), the resolution becomes (672, 1344), which checks out math-wise. Setting image_size to (672, 1344) shrinks it again, to (588, 1176), and so on. I also noticed that other models (Padim, STFPM, Reverse Distillation) do not define their own configure_pre_processor function; they use the one in anomalib_module.py instead. I assume this is intended, but I am not very knowledgeable about Anomalib.
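The default quoted from `configure_pre_processor` can be replayed outside Anomalib to see how the sizes evolve. This is a minimal sketch that copies only the formula shown in the issue; `patchcore_default_crop` is a hypothetical helper name, not part of the library.

```python
def patchcore_default_crop(image_size: tuple[int, int]) -> tuple[int, int]:
    """Replay the default center_crop_size formula quoted from configure_pre_processor."""
    height, width = image_size
    return (int(height * (224 / 256)), int(width * (224 / 256)))


if __name__ == "__main__":
    size = (768, 1536)  # (height, width), as passed in the reproduction script
    for _ in range(3):
        cropped = patchcore_default_crop(size)
        print(f"image_size={size} -> output {cropped}")
        size = cropped  # re-specify the output as the new image_size
```

Feeding each output back in as the new image_size gives (768, 1536), then (672, 1344), then (588, 1176), then (514, 1029), and the default (256, 256) gives (224, 224), matching the 224 x 224 resolution seen when no size is set.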
The heat in the heatmaps also becomes offset because of this. Furthermore, if you place an anomaly at the bottom of the image (and presumably near the other edges), Patchcore does not pick it up at all due to the offset: it is not reflected in the anomaly score and shows no heat in the heatmap. Setting a center crop size also fixes this. I would show images of this, but I am currently working on my bachelor's thesis, and the dataset was provided by a company that may not allow me to publish images from it here.
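Assuming `center_crop_size` is applied as a plain center crop after the image is resized to `image_size` (the name suggests this, but it is an assumption here), the 224/256 factor does not rescale the image: it cuts the borders away. The sketch below computes which region survives. `center_crop_box` and `pixel_survives_crop` are hypothetical helpers, and rounding the offset to the nearest pixel is also an assumption.

```python
def center_crop_box(
    image_hw: tuple[int, int], crop_hw: tuple[int, int]
) -> tuple[int, int, int, int]:
    """Return (top, left, bottom, right) of a centered crop; bottom/right are exclusive."""
    image_h, image_w = image_hw
    crop_h, crop_w = crop_hw
    top = round((image_h - crop_h) / 2)   # rows dropped above the crop
    left = round((image_w - crop_w) / 2)  # columns dropped left of the crop
    return top, left, top + crop_h, left + crop_w


def pixel_survives_crop(row: int, col: int, box: tuple[int, int, int, int]) -> bool:
    """True if pixel (row, col) of the resized image lies inside the crop box."""
    top, left, bottom, right = box
    return top <= row < bottom and left <= col < right


if __name__ == "__main__":
    resized = (768, 1536)  # image_size passed to configure_pre_processor
    cropped = (672, 1344)  # default center_crop_size derived from it
    box = center_crop_box(resized, cropped)
    print("kept region (top, left, bottom, right):", box)
    # Hypothetical defect 20 px above the bottom edge, horizontally centered
    print("defect near bottom kept?", pixel_survives_crop(748, 768, box))
    print("defect in the middle kept?", pixel_survives_crop(384, 768, box))
```

For (768, 1536) cropped to (672, 1344), the kept region is (48, 96, 720, 1440): 48 rows are dropped at the top and bottom and 96 columns at the left and right. A defect in the bottom 48 rows therefore never reaches the model, which would be consistent with the missing score and heat described above.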
On a side note, should I be using 224 x 224 images for Patchcore, or does 256 x 256 work fine? I noticed that if you don't set a resolution, it automatically downscales to 224 x 224, whereas other models downscale to 256 x 256.
Dataset
Folder
Model
PatchCore
Steps to reproduce the behavior
Some of this code is irrelevant. I tested with and without tiling and the visualizer.
```python
import torch
from anomalib.data import Folder
from anomalib.engine import Engine
from anomalib.models import Patchcore
from anomalib.callbacks import TilerConfigurationCallback
from anomalib.visualization import ImageVisualizer

tiler_config_callback = TilerConfigurationCallback(enable=True, tile_size=[256, 256], stride=64)
pre_processor = Patchcore.configure_pre_processor(image_size=(768, 1536)) # Height, Width

visualizer = ImageVisualizer(
    fields=["image", "anomaly_map"], # Only visualize image and anomaly map (no masks)
    overlay_fields=[("image", ["anomaly_map"])], # Only overlay anomaly map on image
    field_size=(1536, 768), # Set size (Width, Height)
    output_dir="results/Patchcore/visualizations", # Set output directory
)

datamodule = Folder(
    root="./datasets/dataset",
    name="dataset",
    normal_dir="ok",
    normal_test_dir="ok_test",
    abnormal_dir="nok",
    val_split_mode='same_as_test',
    train_batch_size=1,
    eval_batch_size=1,
    num_workers=11,
)
datamodule.setup()

model = Patchcore(
    backbone="wide_resnet50_2",
    layers=["layer2", "layer3"],
    pre_trained=True,
    coreset_sampling_ratio=0.1,
    num_neighbors=9,
    pre_processor=pre_processor,
    visualizer=visualizer
)

engine = Engine(callbacks=[tiler_config_callback])
engine.fit(model=model, datamodule=datamodule)
```
Then use `engine.test` or `engine.predict`. The image resolution can be seen with `print(prediction.image.shape)`.
OS information
OS: Windows 11
Python version: 3.11.9
Anomalib version: 2.0
PyTorch version: 2.6.0
CUDA/cuDNN version: 12.4
GPU models and configuration: 1x GeForce RTX 3060 Ti
Any other relevant information: I use Jupyter Notebook
Expected behavior
The images would maintain the specified resolution.
Screenshots
No response
Pip/GitHub
pip
What version/branch did you use?
No response
Configuration YAML
N/A
Logs
N/A
Code of Conduct
I agree to follow this project's Code of Conduct
> On a sidenote, should I be using 224 x 224 images for Patchcore, or does 256 x 256 work fine? I noticed that if you don't set a resolution it automatically downscales to 224 x 224, but for other models it downscales to 256 x 256.
I just realized this happens because of the same code. The resolution defaults to 256 x 256, but then gets downscaled to 224 x 224.
Hello, you do not need to use the center crop transform for PatchCore (256x256 works fine). I believe it was added to reproduce the paper's settings, which are not necessarily applicable in practical applications.