Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,10 @@ This package makes use of the following tools and libraries:

- **auto-verify** ([GitHub](https://github.com/ADA-research/auto-verify))
- For integrating verifiers [nnenum](https://github.com/stanleybak/nnenum), [AB-Crown](https://github.com/Verified-Intelligence/alpha-beta-CROWN), [VeriNet](https://github.com/vas-group-imperial/VeriNet), and [Oval-Bab](https://github.com/oval-group/oval-bab). Please refer to the auto-verify [documentation](https://ada-research.github.io/auto-verify/) for details about auto-verify.

- **foolbox** ([GitHub](https://github.com/bethgelab/foolbox))
- Rauber, J., Brendel, W., and Bethge, M., "Foolbox: A Python toolbox to benchmark the robustness of machine learning models," in *Reliable Machine Learning in the Wild Workshop, 34th International Conference on Machine Learning*, 2017. [Online]. Available: http://arxiv.org/abs/1707.04131
- Rauber, J., Zimmermann, R., Bethge, M., and Brendel, W., "Foolbox Native: Fast adversarial attacks to benchmark the robustness of machine learning models in PyTorch, TensorFlow, and JAX," *Journal of Open Source Software*, vol. 5, no. 53, p. 2607, 2020. [Online]. Available: https://doi.org/10.21105/joss.02607

We thank the authors and maintainers of these projects, as well as the authors and maintainers of the verifiers for their contributions to the robustness research community.

17 changes: 17 additions & 0 deletions ada_verona/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@

# Dataset sampler classes
from .dataset_sampler.dataset_sampler import DatasetSampler
from .dataset_sampler.identity_sampler import IdentitySampler
from .dataset_sampler.predictions_based_sampler import PredictionsBasedSampler

# Epsilon value estimator classes
Expand Down Expand Up @@ -87,12 +88,22 @@
stacklevel=2,
)

# Check for foolbox availability
HAS_FOOLBOX = importlib.util.find_spec("foolbox") is not None
if not HAS_FOOLBOX:
warnings.warn(
"Foolbox not found. Some adversarial attack features will be limited. "
"To install: pip install foolbox",
stacklevel=2,
)


__all__ = [
"__version__",
"__author__",
"HAS_AUTOATTACK",
"HAS_AUTOVERIFY",
"HAS_FOOLBOX",
# Core abstract classes
"DatasetSampler",
"EpsilonValueEstimator",
Expand All @@ -114,6 +125,7 @@
"DataPoint",
# Dataset sampler classes
"PredictionsBasedSampler",
"IdentitySampler",
"PytorchExperimentDataset",
"ImageFileDataset",
# Epsilon value estimator classes
Expand Down Expand Up @@ -146,3 +158,8 @@
"parse_counter_example_label",
]
)

if HAS_FOOLBOX:
foolbox_attack_module = importlib.import_module(".verification_module.attacks.foolbox_attack", __package__)
FoolboxAttack = foolbox_attack_module.FoolboxAttack
__all__.extend(["FoolboxAttack"])
95 changes: 95 additions & 0 deletions ada_verona/verification_module/attacks/foolbox_attack.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Copyright 2025 ADA Reseach Group and VERONA council. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import foolbox
from torch import Tensor, nn

from ada_verona.verification_module.attacks.attack import Attack


class FoolboxAttack(Attack):
"""
A wrapper for Foolbox adversarial attacks.
Requires foolbox to be installed: pip install foolbox

Attributes:
attack_cls (class): The Foolbox attack class to use.
kwargs (dict): Arguments to pass to the attack constructor.
"""

def __init__(self, attack_cls, bounds=(0, 1), **kwargs) -> None:
"""
Initialize the FoolboxAttack wrapper.

Args:
attack_cls (class): The Foolbox attack class (e.g., foolbox.attacks.LinfPGD).
bounds (tuple, optional): The bounds of the input data. Defaults to (0, 1).
**kwargs: Arguments to be passed to the attack constructor (e.g., steps=40).
"""
super().__init__()
self.attack_cls = attack_cls
self.bounds = bounds
self.kwargs = kwargs
self.name = f"FoolboxAttack ({attack_cls.__name__}, bounds={bounds}, {kwargs})"

def execute(self, model: nn.Module, data: Tensor, target: Tensor, epsilon: float) -> Tensor:
"""
Execute the Foolbox attack on the given model and data.

Args:
model (nn.Module): The model to attack.
data (Tensor): The input data to perturb.
target (Tensor): The target labels for the data.
epsilon (float): The perturbation magnitude.

Returns:
Tensor: The perturbed data.
"""
fmodel = foolbox.PyTorchModel(model, bounds=self.bounds)

attack = self.attack_cls(**self.kwargs)

# Ensure data has batch dimension (Foolbox requires batch dimension)
# Data should be (batch_size, channels, height, width) or (batch_size, features)
# Foolbox expects at least 2D tensors: (batch_size, ...)
if data.dim() == 0:
# Scalar, add batch dimension: (1,)
data = data.unsqueeze(0)
elif data.dim() == 1:
# 1D tensor, add batch dimension: (1, features)
data = data.unsqueeze(0)
elif data.dim() == 3:
# 3D tensor (C, H, W), add batch dimension: (1, C, H, W)
data = data.unsqueeze(0)
# If data is already 4D (B, C, H, W) or 2D (B, features), keep as is
# But verify it has a batch dimension
if data.dim() >= 2 and data.shape[0] == 0:
raise ValueError(f"Data tensor has invalid batch size: {data.shape}")

# Ensure target has batch dimension
# Target should be 1D with shape (batch_size,) for a single sample: (1,)
if target.dim() == 0:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happends when its a targeted attack?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

targeted attacks are currently not supported - the target here is the correct label

AttackEstimationModule.verify() currently always passes the true label (verification_context.data_point.label) as the "target"

# Scalar target, add batch dimension
target = target.unsqueeze(0)
elif target.dim() == 1:
# Already 1D, should be fine (typically shape (1,) for single sample)
# But ensure it's not empty
if target.shape[0] == 0:
raise ValueError("Target tensor cannot be empty")
# If target is already correct shape, keep as is

_, clipped_advs, _ = attack(fmodel, data, target, epsilons=epsilon)

return clipped_advs
1 change: 1 addition & 0 deletions docs/how-to-guides.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ VERONA implements the following adversarial attack methods:
- **Fast Gradient Sign Method (FGSM)** - [Goodfellow et al., 2015](https://arxiv.org/abs/1412.6572)
- **Projected Gradient Descent (PGD)** - [Madry et al., 2018](https://arxiv.org/abs/1706.06083)
- **AutoAttack** - [Croce and Hein, 2020](https://proceedings.mlr.press/v119/croce20b.html)
- All attacks from [foolbox](https://github.com/bethgelab/foolbox/tree/master/foolbox/attacks) through the [`FoolboxAttack`](../ada_verona/verification_module/attacks/foolbox_attack.py) class

### Optional: AutoAttack Installation

Expand Down
84 changes: 84 additions & 0 deletions examples/scripts/create_robustness_dist_foolbox.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright 2025 ADA Reseach Group and VERONA council. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import logging
from pathlib import Path

import numpy as np
from foolbox.attacks import LinfPGD

import ada_verona.util.logger as logger
from ada_verona.database.dataset.image_file_dataset import ImageFileDataset
from ada_verona.database.experiment_repository import ExperimentRepository
from ada_verona.dataset_sampler.predictions_based_sampler import PredictionsBasedSampler
from ada_verona.epsilon_value_estimator.binary_search_epsilon_value_estimator import (
BinarySearchEpsilonValueEstimator,
)
from ada_verona.verification_module.attack_estimation_module import AttackEstimationModule
from ada_verona.verification_module.attacks.foolbox_attack import FoolboxAttack
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add the check whether foolbox is installed like we have for the other scripts

from ada_verona.verification_module.property_generator.one2any_property_generator import (
One2AnyPropertyGenerator,
)

logger.setup_logging(level=logging.INFO)

experiment_name = "foolbox_pgd"
timeout = 600
experiment_repository_path = Path("../example_experiment/results_foolbox")
network_folder = Path("../example_experiment/data/networks")
image_folder = Path("../example_experiment/data/images")
image_label_file = Path("../example_experiment/data/image_labels.csv")
epsilon_list = np.arange(0.00, 0.4, 0.0039)

dataset = ImageFileDataset(image_folder=image_folder, label_file=image_label_file)

file_database = ExperimentRepository(base_path=experiment_repository_path, network_folder=network_folder)

file_database.initialize_new_experiment(experiment_name)

file_database.save_configuration(
dict(
experiment_name=experiment_name,
experiment_repository_path=str(experiment_repository_path),
network_folder=str(network_folder),
dataset=str(dataset),
timeout=timeout,
epsilon_list=[str(x) for x in epsilon_list],
)
)

property_generator = One2AnyPropertyGenerator()
verifier = AttackEstimationModule(attack=FoolboxAttack(LinfPGD, bounds=(0, 1), steps=10))

epsilon_value_estimator = BinarySearchEpsilonValueEstimator(epsilon_value_list=epsilon_list.copy(), verifier=verifier)
dataset_sampler = PredictionsBasedSampler(sample_correct_predictions=True)

network_list = file_database.get_network_list()

print(f"Found {len(network_list)} networks.")

for network in network_list:
print(f"Processing network: {network.name}")
sampled_data = dataset_sampler.sample(network, dataset)
print(f"Sampled {len(sampled_data)} data points.")

for i, data_point in enumerate(sampled_data):
print(f"Verifying data point {i}...")
verification_context = file_database.create_verification_context(network, data_point, property_generator)
epsilon_value_result = epsilon_value_estimator.compute_epsilon_value(verification_context)
print(f"Result: {epsilon_value_result}")
file_database.save_result(epsilon_value_result)

print("Done.")
90 changes: 90 additions & 0 deletions examples/scripts/create_robustness_dist_foolbox_cw.py
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

potentially leave out of mandatory dependencies

Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Copyright 2025 ADA Reseach Group and VERONA council. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import logging
from pathlib import Path

import numpy as np
from foolbox.attacks import L2CarliniWagnerAttack

import ada_verona.util.logger as logger
from ada_verona.database.dataset.image_file_dataset import ImageFileDataset
from ada_verona.database.experiment_repository import ExperimentRepository
from ada_verona.dataset_sampler.predictions_based_sampler import PredictionsBasedSampler
from ada_verona.epsilon_value_estimator.binary_search_epsilon_value_estimator import (
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need two example scripts for such a small different?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will remove one

BinarySearchEpsilonValueEstimator,
)
from ada_verona.verification_module.attack_estimation_module import AttackEstimationModule
from ada_verona.verification_module.attacks.foolbox_attack import FoolboxAttack
from ada_verona.verification_module.property_generator.one2any_property_generator import (
One2AnyPropertyGenerator,
)

logger.setup_logging(level=logging.INFO)

experiment_name = "foolbox_cw"
timeout = 600
experiment_repository_path = Path("../example_experiment/results_foolbox_cw")
network_folder = Path("../example_experiment/data/networks")
image_folder = Path("../example_experiment/data/images")
image_label_file = Path("../example_experiment/data/image_labels.csv")

epsilon_list = np.arange(0.00, 0.4, 0.0039)

dataset = ImageFileDataset(image_folder=image_folder, label_file=image_label_file)

file_database = ExperimentRepository(base_path=experiment_repository_path, network_folder=network_folder)

file_database.initialize_new_experiment(experiment_name)

file_database.save_configuration(
dict(
experiment_name=experiment_name,
experiment_repository_path=str(experiment_repository_path),
network_folder=str(network_folder),
dataset=str(dataset),
timeout=timeout,
epsilon_list=[str(x) for x in epsilon_list],
)
)

property_generator = One2AnyPropertyGenerator()
verifier = AttackEstimationModule(attack=FoolboxAttack(L2CarliniWagnerAttack, bounds=(0, 1), steps=100))

epsilon_value_estimator = BinarySearchEpsilonValueEstimator(epsilon_value_list=epsilon_list.copy(), verifier=verifier)
dataset_sampler = PredictionsBasedSampler(sample_correct_predictions=True)

network_list = file_database.get_network_list()

print(f"Found {len(network_list)} networks.")

for network in network_list:
print(f"Processing network: {network.name}")
sampled_data = dataset_sampler.sample(network, dataset)
print(f"Sampled {len(sampled_data)} data points.")

for i, data_point in enumerate(sampled_data):
if i >= 1:
break

print(f"Verifying data point {i}...")
verification_context = file_database.create_verification_context(network, data_point, property_generator)

epsilon_value_result = epsilon_value_estimator.compute_epsilon_value(verification_context)

print(f"Result: {epsilon_value_result}")
file_database.save_result(epsilon_value_result)

print("Done.")
1 change: 1 addition & 0 deletions pyproject.toml
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

its not mandatory though, treat like AutoAttack and AutoVerify\

Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ dependencies = [
"onnx>=1.14.0",
"onnxruntime>=1.14.1",
"onnx2torch>=1.5.14",
"foolbox>=3.3.4",
"pandas>=2.0.1",
"PyYAML>=6.0.1",
"result>=0.9.0",
Expand Down
16 changes: 14 additions & 2 deletions tests/test_verification_module/attacks/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@
# limitations under the License.
# ==============================================================================

import foolbox as fb
import pytest
import torch
from torch import nn

from ada_verona.verification_module.attacks.auto_attack_wrapper import AutoAttackWrapper
from ada_verona.verification_module.attacks.fgsm_attack import FGSMAttack
from ada_verona.verification_module.attacks.foolbox_attack import FoolboxAttack
from ada_verona.verification_module.attacks.pgd_attack import PGDAttack


Expand All @@ -34,22 +36,32 @@ def forward(self, x):

return SimpleModel()


@pytest.fixture
def data():
return torch.randn(1, 10)


@pytest.fixture
def target():
return torch.tensor([1])


@pytest.fixture
def attack_wrapper():
return AutoAttackWrapper(device="cpu", norm="Linf", version="standard", verbose=False)


@pytest.fixture
def pgd_attack():
return PGDAttack(number_iterations=10, step_size=0.01, randomise=True)
return PGDAttack(number_iterations=10, step_size=0.01, randomise=True)


@pytest.fixture
def fgsm_attack():
return FGSMAttack()
return FGSMAttack()


@pytest.fixture
def foolbox_attack():
return FoolboxAttack(attack_cls=fb.attacks.LinfFastGradientAttack)
Loading
Loading