9 changes: 3 additions & 6 deletions CTRAIN/model_wrappers/crown_ibp_model_wrapper.py

@@ -14,7 +14,7 @@ class CrownIBPModelWrapper(CTRAINWrapper):
"""

def __init__(self, model, input_shape, eps, num_epochs, train_eps_factor=1, optimizer_func=torch.optim.Adam, lr=0.0005, warm_up_epochs=1, ramp_up_epochs=70,
- lr_decay_factor=.2, lr_decay_milestones=(80, 90), gradient_clip=10, l1_reg_weight=0.000001,
+ lr_scheduler_func=torch.optim.lr_scheduler.MultiStepLR, lr_decay_kwargs=dict(milestones=(80, 90), gamma=0.2), gradient_clip=10, l1_reg_weight=0.000001,
shi_reg_weight=.5, shi_reg_decay=True, start_beta=1, end_beta=0,
loss_fusion=True, checkpoint_save_path=None, checkpoint_save_interval=10,
bound_opts=dict(conv_mode='patches', relu='adaptive'), device=torch.device('cuda')):
@@ -45,14 +45,12 @@ def __init__(self, model, input_shape, eps, num_epochs, train_eps_factor=1, opti
bound_opts (dict): Options for bounding according to the auto_LiRPA documentation.
device (torch.device): Device to run the training on.
"""
- super().__init__(model, eps, input_shape, train_eps_factor, lr, optimizer_func, bound_opts, device, checkpoint_save_path=checkpoint_save_path, checkpoint_save_interval=checkpoint_save_interval)
+ super().__init__(model, eps, input_shape, train_eps_factor, lr, optimizer_func, lr_scheduler_func, lr_decay_kwargs, bound_opts, device, checkpoint_save_path=checkpoint_save_path, checkpoint_save_interval=checkpoint_save_interval)
self.cert_train_method = 'crown_ibp'
self.num_epochs = num_epochs
self.lr = lr
self.warm_up_epochs = warm_up_epochs
self.ramp_up_epochs = ramp_up_epochs
- self.lr_decay_factor = lr_decay_factor
- self.lr_decay_milestones = lr_decay_milestones
self.gradient_clip = gradient_clip
self.l1_reg_weight = l1_reg_weight
self.shi_reg_weight = shi_reg_weight
@@ -132,8 +130,7 @@ def train_model(self, train_loader, val_loader=None, start_epoch=0, end_epoch=No
eps_schedule=(self.warm_up_epochs, self.ramp_up_epochs),
eps_scheduler_args={'start_kappa': self.start_kappa, 'end_kappa': self.end_kappa, 'start_beta': self.start_beta, 'end_beta': self.end_beta},
optimizer=self.optimizer if not self.loss_fusion else self.loss_fusion_optimizer,
- lr_decay_schedule=self.lr_decay_milestones,
- lr_decay_factor=self.lr_decay_factor,
+ lr_scheduler=self.lr_scheduler,
n_classes=self.n_classes,
loss_fusion=self.loss_fusion,
gradient_clip=self.gradient_clip,
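
For callers migrating to the new interface, the removed milestone/factor pair maps onto the new arguments roughly as follows. This mapping is an illustrative sketch rather than part of the PR, but MultiStepLR with gamma set to the old decay factor reproduces the previous schedule:

import torch

# Old-style keyword arguments (removed by this PR):
#   lr_decay_factor=0.2, lr_decay_milestones=(80, 90)
# Equivalent new-style keyword arguments:
equivalent_kwargs = dict(
    lr_scheduler_func=torch.optim.lr_scheduler.MultiStepLR,
    lr_decay_kwargs=dict(milestones=(80, 90), gamma=0.2),  # gamma replaces lr_decay_factor
)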

6 changes: 5 additions & 1 deletion CTRAIN/model_wrappers/model_wrapper.py

@@ -16,7 +16,8 @@ class CTRAINWrapper(nn.Module):
"""
Wrapper base class for certifiably training models.
"""
- def __init__(self, model: nn.Module, eps:float, input_shape: tuple, train_eps_factor=1, lr=0.0005, optimizer_func=torch.optim.Adam, bound_opts=dict(conv_mode='patches', relu='adaptive'), device='cuda', checkpoint_save_path=None, checkpoint_save_interval=10):
+ def __init__(self, model: nn.Module, eps:float, input_shape: tuple, train_eps_factor=1, lr=0.0005, optimizer_func=torch.optim.Adam,
+ lr_scheduler_func=torch.optim.lr_scheduler.MultiStepLR, lr_decay_kwargs=dict(milestones=(80, 90), gamma=0.2), bound_opts=dict(conv_mode='patches', relu='adaptive'), device='cuda', checkpoint_save_path=None, checkpoint_save_interval=10):
"""
Initialize the CTRAINWrapper Base Class.

@@ -73,6 +74,9 @@ def __init__(self, model: nn.Module, eps:float, input_shape: tuple, train_eps_fa

self.optimizer_func = optimizer_func
self.optimizer = optimizer_func(self.bounded_model.parameters(), lr=lr)

+ self.lr_scheduler_func = lr_scheduler_func
+ self.lr_scheduler = self.lr_scheduler_func(self.optimizer, **lr_decay_kwargs)

self.epoch = 0
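
Because the base class now instantiates the scheduler as lr_scheduler_func(self.optimizer, **lr_decay_kwargs), any PyTorch learning-rate scheduler with a compatible constructor can be plugged in. A minimal usage sketch; the import path, model, and CosineAnnealingLR choice are illustrative assumptions rather than part of the PR:

import torch
import torch.nn as nn
from CTRAIN.model_wrappers import ShiIBPModelWrapper  # assumed import path

net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 10))

wrapper = ShiIBPModelWrapper(
    net,
    input_shape=(1, 28, 28),
    eps=0.1,
    num_epochs=70,
    # lr_decay_kwargs is forwarded verbatim to the scheduler constructor.
    lr_scheduler_func=torch.optim.lr_scheduler.CosineAnnealingLR,
    lr_decay_kwargs=dict(T_max=70),
    device=torch.device("cpu"),
)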


9 changes: 3 additions & 6 deletions CTRAIN/model_wrappers/mtl_ibp_model_wrapper.py

@@ -12,7 +12,7 @@ class MTLIBPModelWrapper(CTRAINWrapper):
"""

def __init__(self, model, input_shape, eps, num_epochs, train_eps_factor=1, optimizer_func=torch.optim.Adam, lr=0.0005, warm_up_epochs=1, ramp_up_epochs=70,
- lr_decay_factor=.2, lr_decay_milestones=(80, 90), gradient_clip=10, l1_reg_weight=0.000001,
+ lr_scheduler_func=torch.optim.lr_scheduler.MultiStepLR, lr_decay_kwargs=dict(milestones=(80, 90)), gradient_clip=10, l1_reg_weight=0.000001,
shi_reg_weight=.5, shi_reg_decay=True, pgd_steps=1,
pgd_alpha=10, pgd_restarts=1, pgd_early_stopping=False, pgd_alpha_decay_factor=.1,
pgd_decay_milestones=(), pgd_eps_factor=1, mtl_ibp_alpha=0.5, checkpoint_save_path=None, checkpoint_save_interval=10,
@@ -49,14 +49,12 @@ def __init__(self, model, input_shape, eps, num_epochs, train_eps_factor=1, opti
bound_opts (dict): Options for bounding according to the auto_LiRPA documentation.
device (torch.device): Device to run the training on.
"""
- super().__init__(model, eps, input_shape, train_eps_factor, lr, optimizer_func, bound_opts, device, checkpoint_save_path=checkpoint_save_path, checkpoint_save_interval=checkpoint_save_interval)
+ super().__init__(model, eps, input_shape, train_eps_factor, lr, optimizer_func, lr_scheduler_func, lr_decay_kwargs, bound_opts, device, checkpoint_save_path=checkpoint_save_path, checkpoint_save_interval=checkpoint_save_interval)
self.cert_train_method = 'mtl_ibp'
self.num_epochs = num_epochs
self.lr = lr
self.warm_up_epochs = warm_up_epochs
self.ramp_up_epochs = ramp_up_epochs
- self.lr_decay_factor = lr_decay_factor
- self.lr_decay_milestones = lr_decay_milestones
self.gradient_clip = gradient_clip
self.l1_reg_weight = l1_reg_weight
self.shi_reg_weight = shi_reg_weight
@@ -100,8 +98,7 @@ def train_model(self, train_loader, val_loader=None, start_epoch=0, end_epoch=No
eps_std=eps_std,
eps_schedule=(self.warm_up_epochs, self.ramp_up_epochs),
optimizer=self.optimizer,
- lr_decay_schedule=self.lr_decay_milestones,
- lr_decay_factor=self.lr_decay_factor,
+ lr_scheduler=self.lr_scheduler,
n_classes=self.n_classes,
gradient_clip=self.gradient_clip,
l1_regularisation_weight=self.l1_reg_weight,

9 changes: 3 additions & 6 deletions CTRAIN/model_wrappers/sabr_model_wrapper.py

@@ -13,7 +13,7 @@ class SABRModelWrapper(CTRAINWrapper):
"""

def __init__(self, model, input_shape, eps, num_epochs, train_eps_factor=1, optimizer_func=torch.optim.Adam, lr=0.0005, warm_up_epochs=1, ramp_up_epochs=70,
- lr_decay_factor=.2, lr_decay_milestones=(80, 90), gradient_clip=10, l1_reg_weight=0.000001,
+ lr_scheduler_func=torch.optim.lr_scheduler.MultiStepLR, lr_decay_kwargs=dict(milestones=(80, 90), gamma=0.2), gradient_clip=10, l1_reg_weight=0.000001,
shi_reg_weight=.5, shi_reg_decay=True, sabr_subselection_ratio=.2, pgd_steps=8,
pgd_alpha=0.5, pgd_restarts=1, pgd_early_stopping=False, pgd_alpha_decay_factor=.1,
pgd_decay_milestones=(4,7), pgd_eps_factor=1, checkpoint_save_path=None, checkpoint_save_interval=10,
@@ -50,14 +50,12 @@ def __init__(self, model, input_shape, eps, num_epochs, train_eps_factor=1, opti
bound_opts (dict): Options for bounding according to the auto_LiRPA documentation.
device (torch.device): Device to run the training on.
"""
- super().__init__(model, eps, input_shape, train_eps_factor, lr, optimizer_func, bound_opts, device, checkpoint_save_path=checkpoint_save_path, checkpoint_save_interval=checkpoint_save_interval)
+ super().__init__(model, eps, input_shape, train_eps_factor, lr, optimizer_func, lr_scheduler_func, lr_decay_kwargs, bound_opts, device, checkpoint_save_path=checkpoint_save_path, checkpoint_save_interval=checkpoint_save_interval)
self.cert_train_method = 'sabr'
self.num_epochs = num_epochs
self.lr = lr
self.warm_up_epochs = warm_up_epochs
self.ramp_up_epochs = ramp_up_epochs
- self.lr_decay_factor = lr_decay_factor
- self.lr_decay_milestones = lr_decay_milestones
self.gradient_clip = gradient_clip
self.l1_reg_weight = l1_reg_weight
self.shi_reg_weight = shi_reg_weight
@@ -100,8 +98,7 @@ def train_model(self, train_loader, val_loader=None, start_epoch=0, end_epoch=No
eps_schedule=(self.warm_up_epochs, self.ramp_up_epochs),
eps_scheduler_args={},
optimizer=self.optimizer,
- lr_decay_schedule=self.lr_decay_milestones,
- lr_decay_factor=self.lr_decay_factor,
+ lr_scheduler=self.lr_scheduler,
n_classes=self.n_classes,
gradient_clip=self.gradient_clip,
l1_regularisation_weight=self.l1_reg_weight,

13 changes: 4 additions & 9 deletions CTRAIN/model_wrappers/shi_ibp_model_wrapper.py

@@ -12,10 +12,8 @@ class ShiIBPModelWrapper(CTRAINWrapper):
Wrapper class for training models using SHI-IBP method. For details, see Shi et al. (2021) Fast certified robust training with short warmup. https://proceedings.neurips.cc/paper/2021/file/988f9153ac4fd966ea302dd9ab9bae15-Paper.pdf
"""

- def __init__(self, model, input_shape, eps, num_epochs, train_eps_factor=1, optimizer_func=torch.optim.Adam, lr=0.0005, warm_up_epochs=1, ramp_up_epochs=70,
- lr_decay_factor=.2, lr_decay_milestones=(80, 90), gradient_clip=10, l1_reg_weight=0.000001,
- shi_reg_weight=.5, shi_reg_decay=True, checkpoint_save_path=None, checkpoint_save_interval=10,
- bound_opts=dict(conv_mode='patches', relu='adaptive'), device=torch.device('cuda')):
+ def __init__(self, model, input_shape, eps, num_epochs, train_eps_factor=1, optimizer_func=torch.optim.Adam, lr=0.0005, warm_up_epochs=1, ramp_up_epochs=70,lr_scheduler_func=torch.optim.lr_scheduler.MultiStepLR, lr_decay_kwargs=dict(milestones=(80, 90), gamma=0.2), gradient_clip=10, l1_reg_weight=0.000001, shi_reg_weight=.5, shi_reg_decay=True, checkpoint_save_path=None, checkpoint_save_interval=10,
+ bound_opts=dict(conv_mode='patches', relu='adaptive'), device=torch.device('cuda')):
"""
Initializes the ShiIBPModelWrapper.

@@ -40,14 +38,12 @@ def __init__(self, model, input_shape, eps, num_epochs, train_eps_factor=1, opti
bound_opts (dict): Options for bounding according to the auto_LiRPA documentation.
device (torch.device): Device to run the training on.
"""
- super().__init__(model, eps, input_shape, train_eps_factor, lr, optimizer_func, bound_opts, device, checkpoint_save_path=checkpoint_save_path, checkpoint_save_interval=checkpoint_save_interval)
+ super().__init__(model, eps, input_shape, train_eps_factor, lr, optimizer_func, lr_scheduler_func, lr_decay_kwargs, bound_opts, device, checkpoint_save_path=checkpoint_save_path, checkpoint_save_interval=checkpoint_save_interval)
self.cert_train_method = 'shi'
self.num_epochs = num_epochs
self.lr = lr
self.warm_up_epochs = warm_up_epochs
self.ramp_up_epochs = ramp_up_epochs
- self.lr_decay_factor = lr_decay_factor
- self.lr_decay_milestones = lr_decay_milestones
self.gradient_clip = gradient_clip
self.l1_reg_weight = l1_reg_weight
self.shi_reg_weight = shi_reg_weight
@@ -84,8 +80,7 @@ def train_model(self, train_loader, val_loader=None, start_epoch=0, end_epoch=No
eps_schedule=(self.warm_up_epochs, self.ramp_up_epochs),
eps_scheduler_args={'start_kappa': self.start_kappa, 'end_kappa': self.end_kappa},
optimizer=self.optimizer,
- lr_decay_schedule=self.lr_decay_milestones,
- lr_decay_factor=self.lr_decay_factor,
+ lr_scheduler=self.lr_scheduler,
n_classes=self.n_classes,
gradient_clip=self.gradient_clip,
l1_regularisation_weight=self.l1_reg_weight,

9 changes: 3 additions & 6 deletions CTRAIN/model_wrappers/staps_model_wrapper.py

@@ -17,7 +17,7 @@ class STAPSModelWrapper(CTRAINWrapper):
"""

def __init__(self, model, input_shape, eps, num_epochs, train_eps_factor=1, optimizer_func=torch.optim.Adam, lr=0.0005, warm_up_epochs=1, ramp_up_epochs=70,
- lr_decay_factor=.2, lr_decay_milestones=(80, 90), gradient_clip=10, l1_reg_weight=0.000001,
+ lr_scheduler_func=torch.optim.lr_scheduler.MultiStepLR, lr_decay_kwargs=dict(milestones=(80, 90), gamma=0.2), gradient_clip=10, l1_reg_weight=0.000001,
shi_reg_weight=.5, shi_reg_decay=True, pgd_steps=8,
pgd_alpha=0.5, pgd_restarts=1, pgd_early_stopping=False, pgd_alpha_decay_factor=.1,
pgd_decay_steps=(4,7), sabr_pgd_steps=8, sabr_pgd_alpha=0.5, sabr_pgd_restarts=1,
@@ -67,14 +67,12 @@ def __init__(self, model, input_shape, eps, num_epochs, train_eps_factor=1, opti
bound_opts (dict): Options for bounding according to the auto_LiRPA documentation.
device (torch.device): Device to run the training on.
"""
- super().__init__(model, eps, input_shape, train_eps_factor, lr, optimizer_func, bound_opts, device, checkpoint_save_path=checkpoint_save_path, checkpoint_save_interval=checkpoint_save_interval)
+ super().__init__(model, eps, input_shape, train_eps_factor, lr, optimizer_func, lr_scheduler_func, lr_decay_kwargs, bound_opts, device, checkpoint_save_path=checkpoint_save_path, checkpoint_save_interval=checkpoint_save_interval)
self.cert_train_method = 'staps'
self.num_epochs = num_epochs
self.lr = lr
self.warm_up_epochs = warm_up_epochs
self.ramp_up_epochs = ramp_up_epochs
- self.lr_decay_factor = lr_decay_factor
- self.lr_decay_milestones = lr_decay_milestones
self.gradient_clip = gradient_clip
self.l1_reg_weight = l1_reg_weight
self.shi_reg_weight = shi_reg_weight
@@ -150,8 +148,7 @@ def train_model(self, train_loader, val_loader=None, start_epoch=0, end_epoch=No
eps_schedule=(self.warm_up_epochs, self.ramp_up_epochs),
eps_scheduler_args={},
optimizer=self.optimizer,
- lr_decay_schedule=self.lr_decay_milestones,
- lr_decay_factor=self.lr_decay_factor,
+ lr_scheduler=self.lr_scheduler,
n_classes=self.n_classes,
gradient_clip=self.gradient_clip,
l1_regularisation_weight=self.l1_reg_weight,

9 changes: 3 additions & 6 deletions CTRAIN/model_wrappers/taps_model_wrapper.py

@@ -16,7 +16,7 @@ class TAPSModelWrapper(CTRAINWrapper):
"""

def __init__(self, model, input_shape, eps, num_epochs, train_eps_factor=1, optimizer_func=torch.optim.Adam, lr=0.0005, warm_up_epochs=1, ramp_up_epochs=70,
- lr_decay_factor=.2, lr_decay_milestones=(80, 90), gradient_clip=10, l1_reg_weight=0.000001,
+ lr_scheduler_func=torch.optim.lr_scheduler.MultiStepLR, lr_decay_kwargs=dict(milestones=(80, 90), gamma=0.2), gradient_clip=10, l1_reg_weight=0.000001,
shi_reg_weight=.5, shi_reg_decay=True, pgd_steps=8,
pgd_alpha=0.5, pgd_restarts=1, pgd_early_stopping=False, pgd_alpha_decay_factor=.1,
pgd_decay_steps=(4,7), block_sizes=None, gradient_expansion_alpha=5,
@@ -58,14 +58,12 @@ def __init__(self, model, input_shape, eps, num_epochs, train_eps_factor=1, opti
bound_opts (dict): Options for bounding according to the auto_LiRPA documentation.
device (torch.device): Device to run the training on.
"""
- super().__init__(model, eps, input_shape, train_eps_factor, lr, optimizer_func, bound_opts, device, checkpoint_save_path=checkpoint_save_path, checkpoint_save_interval=checkpoint_save_interval)
+ super().__init__(model, eps, input_shape, train_eps_factor, lr, optimizer_func, lr_scheduler_func, lr_decay_kwargs, bound_opts, device, checkpoint_save_path=checkpoint_save_path, checkpoint_save_interval=checkpoint_save_interval)
self.cert_train_method = 'taps'
self.num_epochs = num_epochs
self.lr = lr
self.warm_up_epochs = warm_up_epochs
self.ramp_up_epochs = ramp_up_epochs
- self.lr_decay_factor = lr_decay_factor
- self.lr_decay_milestones = lr_decay_milestones
self.gradient_clip = gradient_clip
self.l1_reg_weight = l1_reg_weight
self.shi_reg_weight = shi_reg_weight
@@ -128,8 +126,7 @@ def train_model(self, train_loader, val_loader=None, start_epoch=0):
eps_schedule=(self.warm_up_epochs, self.ramp_up_epochs),
eps_scheduler_args={},
optimizer=self.optimizer,
- lr_decay_schedule=self.lr_decay_milestones,
- lr_decay_factor=self.lr_decay_factor,
+ lr_scheduler=self.lr_scheduler,
n_classes=self.n_classes,
gradient_clip=self.gradient_clip,
l1_regularisation_weight=self.l1_reg_weight,

23 changes: 7 additions & 16 deletions CTRAIN/train/certified/crown_ibp.py

@@ -24,8 +24,7 @@ def crown_ibp_train_model(
eps_schedule_unit="epoch",
eps_scheduler_args=dict(),
optimizer=None,
- lr_decay_schedule=(15, 25),
- lr_decay_factor=0.2,
+ lr_scheduler=None,
lr_decay_schedule_unit="epoch",
n_classes=10,
loss_fusion=False,
@@ -105,13 +104,6 @@ def crown_ibp_train_model(
epoch_rob_err = 0
epoch_nat_err = 0

- if lr_decay_schedule_unit == "epoch":
-     if epoch + 1 in lr_decay_schedule:
-         print("LEARNING RATE DECAYED!")
-         cur_lr = cur_lr * lr_decay_factor
-         for g in optimizer.param_groups:
-             g["lr"] = cur_lr

print(
f"[{epoch + 1}/{num_epochs}]: eps {eps_scheduler.get_cur_eps(normalise=False):.4f}"
)
@@ -138,13 +130,6 @@
),
)

- if lr_decay_schedule_unit == "batch":
-     if no_batches + 1 in lr_decay_schedule:
-         print("LEARNING RATE DECAYED!")
-         cur_lr = cur_lr * lr_decay_factor
-         for g in optimizer.param_groups:
-             g["lr"] = cur_lr

data, target = data.to(device), target.to(device)
optimizer.zero_grad()

@@ -222,6 +207,9 @@
eps_scheduler.batch_step()
no_batches += 1

+ if lr_scheduler is not None and lr_decay_schedule_unit == "batch":
+     lr_scheduler.step()

train_acc_nat = 1 - epoch_nat_err / len(train_loader)
train_acc_cert = 1 - epoch_rob_err / len(train_loader)

@@ -237,5 +225,8 @@
hardened_model, optimizer, running_loss, epoch + 1, results_path
)

+ if lr_scheduler is not None and lr_decay_schedule_unit == "epoch":
+     lr_scheduler.step()

return hardened_model
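
The manual decay loops removed above are replaced by the standard scheduler.step() protocol: the scheduler passed in by the wrapper is stepped once per batch or once per epoch, depending on lr_decay_schedule_unit. A condensed sketch of that control flow (illustrative, not the full crown_ibp_train_model body):

def training_loop(num_epochs, train_loader, optimizer, lr_scheduler=None, lr_decay_schedule_unit="epoch"):
    # Condensed illustration of the scheduler handling introduced by this PR.
    for epoch in range(num_epochs):
        for data, target in train_loader:
            optimizer.zero_grad()
            # ... bound computation, loss, and loss.backward() as in crown_ibp_train_model ...
            optimizer.step()
            if lr_scheduler is not None and lr_decay_schedule_unit == "batch":
                lr_scheduler.step()
        if lr_scheduler is not None and lr_decay_schedule_unit == "epoch":
            lr_scheduler.step()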
