# Code adapted from https://github.com/wagnermoritz/GSE
from vlm_eval.attacks.attack import Attack
import torch
import math
import time


class SAIF(Attack):
def __init__(self, model, *args, targeted=False, img_range=(-1, 1), steps=200,
r0=1, ver=False, k=10000, eps=16./255., mask_out='none', **kwargs):
        '''
        Implementation of the sparse Frank-Wolfe attack SAIF
        (https://arxiv.org/pdf/2212.07495.pdf), adapted from
        https://github.com/wagnermoritz/GSE/tree/main.

        args:
            model:     Callable, model under attack.
            img_range: Tuple of ints/floats, lower and upper bound of image
                       entries.
            targeted:  Bool, if True the attack minimizes the model's loss,
                       otherwise it maximizes it.
            steps:     Int, number of Frank-Wolfe iterations.
            r0:        Int, initial parameter for the step size computation.
            ver:       Bool, print progress if True.
            k:         Int, sparsity parameter (top-k entries selected in each
                       LMO step).
            eps:       Float, maximum perturbation magnitude per entry
                       (L-inf bound).
            mask_out:  'context', 'query', an image index, or 'none'; selects
                       which images are excluded from the perturbation.
        '''
super().__init__(model, targeted=targeted, img_range=img_range)
self.steps = steps
self.r0 = r0
self.loss_fn = torch.nn.CrossEntropyLoss()
self.ver = ver
self.k = k
self.eps = eps
        self.mask_out = None if mask_out == 'none' else mask_out

    def _set_mask(self, data):
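        '''
        Return a binary mask shaped like `data` with zeros on the images that
        mask_out excludes from the perturbation: the context images
        ('context'), the query image ('query'), or a single image index.
        '''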
mask = torch.ones_like(data)
if self.mask_out == 'context':
mask[:, :-1, ...] = 0
elif self.mask_out == 'query':
mask[:, -1, ...] = 0
elif isinstance(self.mask_out, int):
mask[:, self.mask_out, ...] = 0
elif self.mask_out is None:
pass
else:
raise NotImplementedError(f'Unknown mask_out: {self.mask_out}')
return mask

    def __call__(self, x):
        '''
        Perform the attack on a single sample x.

        args:
            x: Tensor of shape [1, N, F, C, H, W], a single sample consisting
               of N images (context images followed by the query image).
               The sparsity parameter k and the magnitude bound eps are taken
               from the constructor.

        Returns a tuple (x_adv, l0): x_adv is a tensor of the same shape as x
        containing the adversarial example, l0 is the number of perturbed
        entries.
        '''
        assert x.shape[0] == 1, "Only batch size 1 is supported for now"
for param in self.model.model.parameters():
param.requires_grad = False
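        # x is indexed as [1, B, _, C, H, W]; the flattened top-k masks built
        # below hold B * C * H * W entries, so reshaping them with
        # view(*x.shape) assumes the remaining dims are of size 1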
B, C, H, W = x.shape[1], x.shape[3], x.shape[4], x.shape[5]
x = x.to(self.device)
batchidx = torch.arange(B).view(-1, 1)
mask_out = self._set_mask(x)
# compute p_0 and s_0
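        # SAIF searches for a perturbation of the form s * p, where p is a dense
        # perturbation with ||p||_inf <= eps and s is a (relaxed) binary mask
        # built from at most k coordinates per LMO step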
x_ = x.clone()
x_.requires_grad = True
out = self.model(x_)
loss = -out.sum() if not self.targeted else out.sum()
x__grad = torch.autograd.grad(loss, [x_])[0].detach() * mask_out
p = -self.eps * x__grad.sign()
p = p.detach().half()
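        # s_0: ones on the (at most) k coordinates with the most negative gradient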
ksmallest = torch.topk(-x__grad.view(B, -1), self.k, dim=1)[1]
ksmask = torch.zeros((B, C * H * W), device=self.device)
ksmask[batchidx, ksmallest] = 1
s = torch.logical_and(ksmask.view(*x.shape), x__grad < 0).float()
s = s.detach().half()
r = self.r0
for t in range(self.steps):
if self.ver:
print(f'\r Iteration {t+1}/{self.steps}', end='')
p.requires_grad = True
s.requires_grad = True
D = self.Loss_fn(x, s, p, mask_out)
D.backward()
mp = p.grad * mask_out
ms = s.grad * mask_out
with torch.no_grad():
# inf-norm LMO
v = (-self.eps * mp.sign()).half()
# 1-norm LMO
ksmallest = torch.topk(-ms.view(B, -1), self.k, dim=1)[1]
ksmask = torch.zeros((B, C * H * W), device=self.device)
ksmask[batchidx, ksmallest] = 1
ksmask = ksmask.view(*x.shape) * mask_out
z = torch.logical_and(ksmask, ms < 0).float().half()
# update stepsize until primal progress is made
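                # step size mu = 1 / (2^r * sqrt(t + 1)); r is increased until the
                # candidate step no longer increases the objective D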
                mu = 1 / (2 ** r * math.sqrt(t + 1))
                while self.Loss_fn(x, s + mu * (z - s), p + mu * (v - p), mask_out) > D:
                    r += 1
                    if r >= 50:
                        break
                    mu = 1 / (2 ** r * math.sqrt(t + 1))
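                # Frank-Wolfe update: move p and s towards the LMO solutions v and z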
p = p + mu * (v - p)
s = s + mu * (z - s)
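                # clamp x + p to the valid image range and fold the clipping back into p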
x_adv = torch.clamp(x + p, *self.img_range)
p = x_adv - x
if self.ver and t % 10 == 0:
print(f" Loss: {D}")
if self.ver:
print('')
        return (x + s * p * mask_out).detach(), torch.norm(s * p, p=0).item()

    def Loss_fn(self, x, s, p, mask_out):
        '''
        SAIF objective: the model's loss at x + s * p * mask_out, negated in
        the untargeted case so that minimizing it increases the loss.
        '''
        out = self.model(x + s * p * mask_out).sum()
        return out if self.targeted else -out
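
# Usage sketch (illustrative; assumes `model` is a vlm_eval model wrapper whose
# forward pass returns a per-sample loss, and `vision_x` is an image tensor of
# shape [1, N, 1, C, H, W] normalized to (-1, 1)):
#
#     attack = SAIF(model, eps=16. / 255., k=10000, steps=200,
#                   img_range=(-1, 1), mask_out='none')
#     x_adv, l0 = attack(vision_x)
#     # x_adv: adversarial images, same shape as vision_x
#     # l0:    number of perturbed entries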