# Code taken and adapted from https://github.com/wagnermoritz/GSE
import torch

from vlm_eval.attacks.attack import Attack

class PGD0(Attack):

    def __init__(self, model, *args, img_range=(0, 1), k=5000, n_restarts=1,
                 targeted=False, iters=200, stepsize=120000 / 255.0,
                 eps=4.0 / 255.0, ver=False, mask_out='none', **kwargs):
        '''
        Implementation of the PGD_0 attack: https://arxiv.org/pdf/1909.05040
        Authors' implementation: https://github.com/fra31/sparse-imperceivable-attacks/tree/master
        Adapted from: https://github.com/wagnermoritz/GSE/tree/main

        args:
            model: Callable, PyTorch classifier.
            img_range: Tuple of ints/floats, lower and upper bound of image
                entries.
            targeted: Bool, given label is used as a target label if True.
            k: Int, sparsity parameter (maximum number of perturbed pixels).
            n_restarts: Int, number of restarts from random perturbation.
            iters: Int, number of gradient descent steps per restart.
            stepsize: Float, step size for gradient descent.
            eps: Float, componentwise bound on the perturbation entries.
            ver: Bool, print the loss every 20 iterations if True.
            mask_out: 'none', 'context', 'query', or Int; images along dim 1
                of the input excluded from the perturbation (see _set_mask).
        '''
super().__init__(model, img_range=img_range, targeted=targeted)
self.k = k
self.n_restarts = n_restarts
self.eps = eps
self.iters = iters
self.stepsize = stepsize
        self.mask_out = mask_out if mask_out != 'none' else None
self.ver = ver
def _set_mask(self, data):
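        '''
        Return a {0, 1} mask of the same shape as data that excludes images
        along dim 1 from the perturbation: 'context' masks all but the last
        image, 'query' masks the last image, an Int masks that index, and
        None masks nothing.
        '''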
mask = torch.ones_like(data)
if self.mask_out == 'context':
mask[:, :-1, ...] = 0
elif self.mask_out == 'query':
mask[:, -1, ...] = 0
elif isinstance(self.mask_out, int):
mask[:, self.mask_out, ...] = 0
elif self.mask_out is None:
pass
else:
raise NotImplementedError(f'Unknown mask_out: {self.mask_out}')
return mask
def __call__(self, x, *args, **kwargs):
        '''
        Perform the PGD_0 attack on a batch of images x.

        args:
            x: Tensor of shape [1, B, 1, C, H, W], batch of B images (the
               outer singleton dims come from the VLM input format).

        Returns a tensor of the same shape as x containing adversarial
        examples.
        '''
for param in self.model.model.parameters():
param.requires_grad = False
mask_out = self._set_mask(x)
x = x.to(self.device)
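        # Dims 0 and 2 of x must be singletons so that the perturbation can
        # be flattened to shape (B, C, H, W) by the .view calls below.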
B, C, H, W = x.shape[1], x.shape[3], x.shape[4], x.shape[5]
for _ in range(self.n_restarts):
if not len(x):
break
eps = torch.full_like(x, self.eps)
            # Componentwise bounds so that x + pert stays inside img_range
            # and every entry of pert is at most eps in absolute value.
            lb = torch.maximum(-eps, self.img_range[0] - x)
            ub = torch.minimum(eps, self.img_range[1] - x)
            pert = torch.clamp(x + (ub - lb) * torch.rand_like(x) + lb, *self.img_range) - x
            pert = pert.view(B, C, H, W) * mask_out.view(B, C, H, W)
            pert = self.project_L0(pert, lb, ub)  # pert has shape (B, C, H, W)
            for it in range(self.iters):
                pert.requires_grad = True
                loss = self.lossfn(x=x, pert=pert.view(*x.shape), mask_out=mask_out)
                loss.backward()
                if self.ver and it % 20 == 0:
                    print(f"Loss: {loss.item()}, Iter: {it}")
                grad = pert.grad.data.view(B, C, H, W) * mask_out.view(B, C, H, W)
                with torch.no_grad():
                    # Normalize by the L1 norm of the gradient, take a descent
                    # step, and add tiny noise to break ties before projecting.
                    grad /= grad.abs().sum(dim=(1, 2, 3), keepdim=True) + 1e-10
                    pert += (torch.rand_like(x) - .5).view(B, C, H, W) * 1e-12 - self.stepsize * grad
                    pert = self.project_L0(pert, lb, ub)
return (x + pert.view(*x.shape) * mask_out).detach()
    def project_L0_sigma(self, pert, sigma, kappa, x_orig):
        '''
        Sigma-aware L0 projection from the original PGD_0 paper. Note: this
        helper is not called by __call__ and is unfinished; it computes the
        clamped multiplier lmbd but never applies it before returning.
        '''
        B, C, H, W = pert.shape
        x = torch.clone(pert)
        eps = torch.tensor(1e-12, device=pert.device)
        p1 = (1.0 / torch.maximum(eps, sigma)) * (x_orig > 0).float() + \
             1e12 * (x_orig == 0).float()
        p2 = (1.0 / torch.maximum(eps, sigma)) * \
             (1.0 / torch.maximum(eps, x_orig) - 1) * \
             (x_orig > 0).float() + 1e12 * (x_orig == 0).float() + \
             1e12 * (sigma == 0).float()
        lmbd_l = torch.maximum(-kappa, torch.amax(-p1, dim=1, keepdim=True))
        lmbd_u = torch.minimum(kappa, torch.amin(p2, dim=1, keepdim=True))
        lmbd_unconstr = torch.sum((pert - x_orig) * sigma * x_orig, dim=1, keepdim=True) / \
            torch.clamp(torch.sum((sigma * x_orig) ** 2, dim=1, keepdim=True), min=1e-12)
        lmbd = torch.maximum(lmbd_l, torch.minimum(lmbd_unconstr, lmbd_u))
        return 0
def project_L0(self, pert, lb, ub):
'''
Project a batch of perturbations such that at most self.k pixels
are perturbed and componentwise there holds lb <= pert <= ub.
'''
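        # p1 is the per-pixel energy of the perturbation (summed over
        # channels); p2 is the squared margin by which it sits inside the box
        # [lb, ub]. The H*W - self.k pixels with the smallest score p1 - p2
        # are zeroed out, leaving at most self.k perturbed pixels per image.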
B, C, H, W = pert.shape # Here, pert is of the shape B, C, H, W
p1 = torch.sum(pert ** 2, dim=1)
p2 = torch.clamp(torch.minimum(ub.view(B, C, H, W) - pert, pert - lb.view(B, C, H, W)), 0)
p2 = torch.sum(p2 ** 2, dim=1)
        p3 = torch.topk(-(p1 - p2).view(p1.size(0), -1), k=H * W - self.k, dim=-1)[1]
        pert = torch.maximum(torch.minimum(pert, ub.view(B, C, H, W)), lb.view(B, C, H, W))
        # p3 holds flattened pixel indices: row = p3 // W, column = p3 % W.
        pert[torch.arange(0, B).view(-1, 1), :, p3 // W, p3 % W] = 0
return pert
    def lossfn(self, x, pert, mask_out):
        '''
        Compute the loss at x + pert * mask_out. The attack performs gradient
        descent on this value, so a targeted attack (sign +1) minimizes the
        model output while an untargeted attack (sign -1) maximizes it.
        '''
        return (2 * self.targeted - 1) * self.model(x + pert * mask_out).sum()
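

# A minimal smoke-test sketch (not part of the original module). ToyModel is a
# hypothetical stand-in for a vlm_eval model wrapper; it assumes the Attack
# base class stores the model as self.model and exposes self.device, and that
# the wrapped model maps a [1, B, 1, C, H, W] batch to one scalar per image.
if __name__ == "__main__":
    class ToyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # __call__ freezes self.model.model.parameters(), so the wrapper
            # must expose an inner .model attribute.
            self.model = torch.nn.Conv2d(3, 1, kernel_size=3, padding=1)

        def forward(self, x):
            B = x.shape[1]
            out = self.model(x.view(-1, *x.shape[-3:]))  # (B, 1, H, W)
            return out.view(B, -1).mean(dim=1)           # one scalar per image

    toy = ToyModel()
    images = torch.rand(1, 2, 1, 3, 32, 32)  # [1, B, 1, C, H, W]
    attack = PGD0(toy, k=50, iters=20, eps=8 / 255, ver=True)
    adv = attack(images)
    print("max |pert|:", (adv - images).abs().max().item(),
          "| perturbed pixels:", (adv - images).abs().amax(dim=-3).gt(0).sum().item())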