Spaces:
Runtime error
Runtime error
File size: 4,755 Bytes
fc0ff8f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
# Code taken and adapted from https://github.com/wagnermoritz/GSE
import torch
from vlm_eval.attacks.attack import Attack
class EAD(Attack):
    """Elastic-net Attack to DNNs (EAD) with binary search over the loss
    trade-off constant ``c``.

    Adapted from https://github.com/wagnermoritz/GSE. The attack minimizes
    ``c * model_loss + ||x - x_orig||_2^2 + beta * ||x - x_orig||_1`` via an
    ISTA/FISTA-style iteration, and binary-searches ``c`` across
    ``binary_steps`` rounds, keeping the perturbed input with the smallest
    L1 distortion that also improves the model loss.
    """

    def __init__(self, model, targeted=False, img_range=(0, 1), steps=100,
                 beta=5e-5, mask_out='none', ver=False, binary_steps=2,
                 step_size=1e-2, decision_rule='L1'):
        """
        Args:
            model: wrapped model; ``model.model`` must expose ``parameters()``.
            targeted: if True, maximize the model score instead of minimizing.
            img_range: valid pixel range, forwarded to the ``Attack`` base.
            steps: inner optimization iterations per binary-search round.
            beta: L1 regularization weight / shrinkage threshold.
            mask_out: 'none' (no masking), 'context', 'query', or an int
                frame index whose gradient is zeroed out.
            ver: verbose — print progress every 20 steps.
            binary_steps: number of binary-search rounds over ``c``.
            step_size: initial gradient step size (decays during the run).
            decision_rule: distortion metric for keeping the best attack;
                only 'L1' is implemented.
        """
        super().__init__(model=model, targeted=targeted, img_range=img_range)
        self.steps = steps
        self.ver = ver  # was assigned twice in the original; once suffices
        self.binary_steps = binary_steps
        self.beta = beta
        # 'none' is a sentinel for "no masking".
        self.mask_out = mask_out if mask_out != 'none' else None
        self.decision_rule = decision_rule
        self.step_size = step_size

    def _set_mask(self, data):
        """Return a 0/1 mask (same shape as ``data``) that zeroes gradients
        for the frames selected by ``self.mask_out`` along dim 1.

        Raises:
            NotImplementedError: for an unrecognized ``mask_out`` value.
        """
        mask = torch.ones_like(data)
        if self.mask_out == 'context':
            mask[:, :-1, ...] = 0   # freeze everything except the last frame
        elif self.mask_out == 'query':
            mask[:, -1, ...] = 0    # freeze only the last frame
        elif isinstance(self.mask_out, int):
            mask[:, self.mask_out, ...] = 0
        elif self.mask_out is None:
            pass                    # no masking
        else:
            raise NotImplementedError(f'Unknown mask_out: {self.mask_out}')
        return mask

    def __call__(self, x_orig):
        """Run the attack and return the best adversarial example found.

        Args:
            x_orig: original input tensor (batch first; dim 1 indexes frames
                for masking purposes).

        Returns:
            Detached tensor of the same shape as ``x_orig``.
        """
        # Freeze model weights — gradients are needed only w.r.t. the input.
        for param in self.model.model.parameters():
            param.requires_grad = False
        mask_out = self._set_mask(x_orig)

        c = 1e-1        # initial trade-off constant
        c_upper = 1e11  # search ceiling (original literal 10e+10 == 1e11)
        c_lower = 0.0
        # c_upper stays at the ceiling until the first success; the
        # `c_upper < 1e9` checks below detect "has a success occurred yet".

        overall_best_attack = x_orig.clone()
        overall_best_dist = torch.inf
        overall_best_loss = 1e10

        for binary_step in range(self.binary_steps):
            global_step = 0
            x = x_orig.clone().detach()
            y = x_orig.clone().detach()
            best_attack = x_orig.clone().detach()
            best_dist = torch.inf
            best_loss = 1e10
            # BUG FIX: the original hard-coded 1e-2 here, silently ignoring a
            # caller-supplied step_size on the first inner iteration.
            step_size = self.step_size

            for step in range(self.steps):
                y.requires_grad = True
                _, loss = self.loss_fn(x=y, c=c, x_orig=x_orig)
                loss.backward()
                y_grad = y.grad.data * mask_out
                with torch.no_grad():
                    # ISTA gradient step followed by shrinkage-projection.
                    x_new = self.project(x=y - step_size * y_grad, x_orig=x_orig)
                    # Polynomial step-size decay (the original wrote this as
                    # `(self.step_size - 0) * (...) ** 0.5 + 0`).
                    step_size = self.step_size * (1 - global_step / self.steps) ** 0.5
                    global_step += 1
                    # FISTA-style momentum term.
                    y = x_new + (step / (step + 3)) * (x_new - x)
                    x = x_new
                    loss_model, loss = self.loss_fn(x=x, c=c, x_orig=x_orig)
                    if self.ver and step % 20 == 0:
                        print(f"Binary Step: {binary_step}, Iter: {step}, Loss: {loss.item()}, L0: {(x - x_orig).norm(p=0)}, Linf: {(x - x_orig).norm(p=torch.inf)}")
                    if self.decision_rule == 'L1':
                        # Keep the iterate only if it improves BOTH the L1
                        # distortion and the model loss.
                        if (x - x_orig).norm(p=1).item() < best_dist and loss_model < best_loss:
                            best_loss = loss_model
                            best_attack = x.clone()
                            best_dist = (x - x_orig).norm(p=1).item()
                    else:
                        raise NotImplementedError

            # Binary search update for c.
            if overall_best_dist > best_dist and best_loss < overall_best_loss:
                # Success at this c: record it and try a smaller c.
                overall_best_loss = best_loss
                overall_best_dist = best_dist
                overall_best_attack = best_attack.clone()
                c_upper = min(c_upper, c)
                if c_upper < 1e9:
                    c = (c_upper + c_lower) / 2
            else:
                # Failure: raise the floor; bisect if we have an upper bound,
                # otherwise grow c geometrically.
                c_lower = max(c_lower, c)
                if c_upper < 1e9:
                    c = (c_lower + c_upper) / 2.0
                else:
                    c *= 10

        print(f"Final L0: {(overall_best_attack - x_orig).norm(p=0)}, Linf: {(overall_best_attack - x_orig).norm(p=torch.inf)}")
        return overall_best_attack.detach()

    def project(self, x, x_orig):
        """Elementwise shrinkage-thresholding of ``x`` around ``x_orig``.

        Entries within ``beta`` of the original are snapped back to
        ``x_orig``; the rest are shrunk by ``beta`` toward ``x_orig`` and
        clipped to the pixel box.

        NOTE(review): the clip bounds are hard-coded to [0, 1] rather than
        the ``img_range`` given to the base class — confirm against Attack.
        """
        mask_1 = (x - x_orig > self.beta).float()
        mask_2 = ((x - x_orig).abs() <= self.beta).float()
        mask_3 = (x - x_orig < -self.beta).float()
        upper = torch.minimum(x - self.beta, torch.tensor(1.0))
        lower = torch.maximum(x + self.beta, torch.tensor(0.0))
        return mask_1 * upper + mask_2 * x_orig + mask_3 * lower

    def loss_fn(self, x, c, x_orig):
        """Return ``(model_loss, total_loss)``.

        ``model_loss`` is the (sign-adjusted) summed model output;
        ``total_loss`` adds the elastic-net penalty:
        ``c * model_loss + ||x - x_orig||_2^2 + beta * ||x - x_orig||_1``.
        """
        out = -self.model(x).sum() if not self.targeted else self.model(x).sum()
        l2_dist = ((x - x_orig) ** 2).view(x.shape[0], -1).sum(dim=1)
        l1_dist = ((x - x_orig).abs()).view(x.shape[0], -1).sum(dim=1)
        return out, c * out + l2_dist.sum() + self.beta * l1_dist.sum()
|