# Code taken and adapted from https://github.com/wagnermoritz/GSE
import torch
import math
from vlm_eval.attacks.attack import Attack
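

# The attack below is a Frank-Wolfe (conditional gradient) method: in every
# iteration it linearizes the loss at the current perturbation delta_t,
# solves the linear maximization oracle (LMO)
#     s_t = argmax <grad, s>  s.t.  ||s|| <= eps  (nuclear group norm),
# and updates delta_{t+1} = (1 - gamma_t) * delta_t + gamma_t * s_t with a
# step size gamma_t found by a golden-section line search. Since the ball is
# convex, the iterates always stay inside it.
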
class FWnucl(Attack):
    def __init__(self, model, *args, iters=200, img_range=(-1, 1), ver=False,
                 targeted=False, eps=5, mask_out='none', **kwargs):
        '''
        Implementation of the nuclear group norm attack.

        args:
            model:     Callable, PyTorch classifier.
            iters:     Int, number of Frank-Wolfe iterations.
            img_range: Tuple of ints/floats, lower and upper bound of image
                       entries.
            ver:       Bool, print progress if True.
            targeted:  Bool, given label is used as a target label if True.
            eps:       Float, radius of the nuclear group norm ball.
            mask_out:  'none', 'context', 'query', or Int; selects which
                       images of the input are excluded from the
                       perturbation (see _set_mask).
        '''
super().__init__(model, img_range=img_range, targeted=targeted)
self.iters = iters
self.ver = ver
self.eps = eps
        self.gr = (math.sqrt(5) + 1) / 2  # golden ratio, used in the line search
if mask_out != 'none':
self.mask_out = mask_out
else:
self.mask_out = None
    def _set_mask(self, data):
        '''
        Build a multiplicative mask that zeroes out the images selected by
        self.mask_out so that they are not perturbed.
        '''
        mask = torch.ones_like(data)
if self.mask_out == 'context':
mask[:, :-1, ...] = 0
elif self.mask_out == 'query':
mask[:, -1, ...] = 0
elif isinstance(self.mask_out, int):
mask[:, self.mask_out, ...] = 0
elif self.mask_out is None:
pass
else:
raise NotImplementedError(f'Unknown mask_out: {self.mask_out}')
return mask
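
    # Example mask semantics, assuming an input layout whose dimension 1
    # indexes the images of a sample (context images first, query last):
    #   mask_out='context' -> only the last (query) image is perturbed
    #   mask_out='query'   -> the query image is left unperturbed
    #   mask_out=k (Int)   -> image k is left unperturbed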
def __loss_fn(self, x):
        '''
        Compute the (summed) model loss; negated for targeted attacks so
        that maximizing it always moves towards the attack goal.
        '''
if self.targeted:
return -self.model(x).sum()
else:
return self.model(x).sum()
def __call__(self, x, *args, **kwargs):
        '''
        Perform the nuclear group norm attack on a batch of images x.

        args:
            x: Tensor, batch of images; the code assumes a six-dimensional
               layout (e.g. [B, N, F, C, H, W]) whose last three dimensions
               are the image channels, height, and width.

        Returns a tensor of the same shape as x containing adversarial
        examples.
        '''
        # freeze the model weights so gradients are only taken w.r.t. noise
        for param in self.model.model.parameters():
            param.requires_grad = False
        x = x.to(self.device)
        mask_out = self._set_mask(x)  # created on the same device as x
        noise = torch.zeros_like(x)
        noise.requires_grad = True
        for t in range(self.iters):
            # Frank-Wolfe step: take the gradient of the loss w.r.t. the
            # (masked) noise and solve the LMO over the nuclear group norm
            # ball to obtain the update direction s
            loss = self.__loss_fn(x + noise * mask_out)
            loss.backward()
            noise.grad.data = noise.grad.data * mask_out
            s = self.__groupNuclearLMO(noise.grad.data, eps=self.eps)
            # move towards s with a step size found by golden-section search
            with torch.no_grad():
                gamma = self.__lineSearch(x=x, s=s, noise=noise)
                noise = (1 - gamma) * noise + gamma * s
            noise.requires_grad = True
            if self.ver:
                print(f'\rIteration {t + 1}/{self.iters}, '
                      f'loss: {loss.item():.5f}', end='')
        # project the adversarial example back onto the valid image range
        x = torch.clamp(x + noise, *self.img_range)
        if self.ver:
            print('')
        return x.detach()
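
    # The step size search below is a golden-section search: it keeps a
    # bracket [a, b] around the maximizer of the one-dimensional function
    # gamma -> loss(x + noise + gamma * (s - noise)) and shrinks the bracket
    # by the golden ratio in every step, so 25 steps narrow it to roughly
    # 0.618^25 (about 6e-6) of its initial length.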
    def __lineSearch(self, x, s, noise, steps=25):
        '''
        Golden-section search for the Frank-Wolfe step size in [0, 1],
        maximizing the loss along the segment from noise to s.
        '''
a = torch.zeros(x.shape[1], device=self.device).view(-1, 1, 1, 1)
b = torch.ones(x.shape[1], device=self.device).view(-1, 1, 1, 1)
c = b - (b - a) / self.gr
d = a + (b - a) / self.gr
sx = s - noise
for i in range(steps):
loss1 = self.__loss_fn(x + noise + (c * sx).view(*x.shape))
loss2 = self.__loss_fn(x + noise + (d * sx).view(*x.shape))
            # shrink the bracket [a, b] towards the probe point with the
            # larger loss
            mask = loss1 > loss2
            b[mask] = d[mask]
            mask = torch.logical_not(mask)
            a[mask] = c[mask]
c = b - (b - a) / self.gr
d = a + (b - a) / self.gr
return (b + a) / 2
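
    # Background for the LMO: over the nuclear norm ball {S : ||S||_* <= eps}
    # the linear function <G, S> is maximized by S = eps * u1 @ v1^T, where
    # u1, v1 are the singular vectors of G belonging to its largest singular
    # value. For the *group* nuclear norm it suffices to spend the whole
    # budget on the single pixel group whose gradient has the largest nuclear
    # norm (summed over color channels), which is what the method below does.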
    def __groupNuclearLMO(self, x, eps=5):
        '''
        LMO for the nuclear group norm ball of radius eps.
        '''
        # x is assumed to be six-dimensional; dimension 1 indexes the images
        # and the last three dimensions are [C, H, W]
        B, C, H, W = x.shape[1], x.shape[3], x.shape[4], x.shape[5]
        size = 32 if H > 64 else 4  # side length of the square pixel groups
# turn batch of images into batch of size by size pixel groups per
# color channel
xrgb = [x.view(B, C, H, W)[:, c, :, :] for c in range(C)]
xrgb = [xc.unfold(1, size, size).unfold(2, size, size) for xc in xrgb]
xrgb = [xc.reshape(-1, size, size) for xc in xrgb]
# compute nuclear norm of each patch (sum norms over color channels)
norms = torch.linalg.svdvals(xrgb[0])
for xc in xrgb[1:]:
norms += torch.linalg.svdvals(xc)
norms = norms.sum(-1).reshape(B, -1)
# only keep the patch g* with the largest nuclear norm for each image
idxs = norms.argmax(dim=1).view(-1, 1)
xrgb = [xc.reshape(B, -1, size, size) for xc in xrgb]
xrgb = [xc[torch.arange(B).view(-1, 1), idxs].view(B, size, size)
for xc in xrgb]
        # build an index tensor corresponding to the position of the kept
        # patches in the flattened image; patches lie on a grid with
        # W // size patches per row
        off = (idxs % (W // size)) * size
        off += (idxs // (W // size)) * W * size
        idxs = torch.arange(0, size**2,
                            device=self.device).view(1, -1).repeat(B, 1) + off
        off = torch.arange(0, size,
                           device=self.device).view(-1, 1).repeat(1, size)
        # row r inside a patch is offset by r * (W - size) relative to a
        # contiguous size x size layout
        off = off * (W - size)
        idxs += off.view(1, -1)
        # the LMO solution places the whole eps budget on the selected patch:
        # it is eps times the outer product of the singular vector pair
        # belonging to the largest singular value, per color channel
        pert = torch.zeros_like(x).view(B, C, H, W)
        batch_idx = torch.arange(B, device=self.device).view(-1, 1)
        for i, xc in enumerate(xrgb):
            U, _, Vh = torch.linalg.svd(xc)
            u = U[:, :, 0].view(B, size, 1)   # top left-singular vector
            v = Vh[:, 0, :].view(B, size, 1)  # top right-singular vector
            pert_gr = torch.bmm(u, v.transpose(-2, -1)).reshape(B, size * size)
            # the flattened channel slice is a view, so this assignment
            # writes the rank-one patch directly into pert
            pert_ch = pert[:, i, :, :].view(B, -1)
            pert_ch[batch_idx, idxs] = pert_gr * eps
        return pert.view(*x.shape)
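

if __name__ == '__main__':
    # Minimal smoke-test sketch, not part of the original code. It assumes
    # that the Attack base class stores self.device and self.img_range, that
    # the wrapped model returns a per-sample loss and exposes its parameters
    # under .model, and that inputs use the six-dimensional layout
    # [B, N, F, C, H, W] (with B = F = 1) that __groupNuclearLMO indexes.
    import torch.nn as nn

    class DummyLossModel(nn.Module):
        '''Hypothetical stand-in for the VLM loss wrapper.'''

        def __init__(self):
            super().__init__()
            self.model = nn.Conv2d(3, 8, 3, padding=1)

        def forward(self, x):
            b, n, f, c, h, w = x.shape
            feats = self.model(x.view(b * n * f, c, h, w))
            return feats.pow(2).mean(dim=(1, 2, 3))  # pseudo-loss per image

    attack = FWnucl(DummyLossModel(), iters=3, eps=5, img_range=(0, 1))
    images = torch.rand(1, 1, 1, 3, 32, 32)  # [B, N, F, C, H, W]
    adv = attack(images)
    print(adv.shape)  # same shape as the input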