KC123hello's picture
Upload Files
fc0ff8f verified
raw
history blame
4.76 kB
# Code taken and adapted from https://github.com/wagnermoritz/GSE
import torch
from vlm_eval.attacks.attack import Attack
class EAD(Attack):
def __init__(self,model, targeted=False, img_range=(0,1), steps=100, beta=5e-5, mask_out='none', ver=False, binary_steps=2, step_size=1e-2, decision_rule='L1'):
super().__init__(model=model, targeted=targeted, img_range=img_range)
self.steps = steps
self.ver = ver
self.binary_steps = binary_steps
self.beta = beta
if mask_out != 'none':
self.mask_out = mask_out
else:
self.mask_out = None
self.decision_rule = decision_rule
self.ver = ver
self.step_size = step_size
def _set_mask(self, data):
mask = torch.ones_like(data)
if self.mask_out == 'context':
mask[:, :-1, ...] = 0
elif self.mask_out == 'query':
mask[:, -1, ...] = 0
elif isinstance(self.mask_out, int):
mask[:, self.mask_out, ...] = 0
elif self.mask_out is None:
pass
else:
raise NotImplementedError(f'Unknown mask_out: {self.mask_out}')
return mask
def __call__(self, x_orig):
for param in self.model.model.parameters():
param.requires_grad = False
mask_out = self._set_mask(x_orig)
c = 1e-1
c_upper = 10e+10
c_lower = 0
overall_best_attack = x_orig.clone()
overall_best_dist = torch.inf
overall_best_loss = 1e10
for binary_step in range(self.binary_steps):
global_step = 0
x = x_orig.clone().detach()
y = x_orig.clone().detach()
best_attack = x_orig.clone().detach()
best_dist = torch.inf
best_loss = 1e10
step_size = 1e-2
for step in range(self.steps):
y.requires_grad = True
_, loss = self.loss_fn(x=y, c=c, x_orig=x_orig)
loss.backward()
y_grad = y.grad.data * mask_out
with torch.no_grad():
x_new = self.project(x=y-step_size*y_grad, x_orig=x_orig)
step_size = (self.step_size - 0) * (1 - global_step / self.steps) ** 0.5 + 0
global_step += 1
y = x_new + (step / (step + 3)) * (x_new - x)
x = x_new
loss_model, loss = self.loss_fn(x=x, c=c, x_orig=x_orig)
if self.ver and step % 20 == 0:
print(f"Binary Step: {binary_step}, Iter: {step}, Loss: {loss.item()}, L0: {(x - x_orig).norm(p=0)}, Linf: {(x - x_orig).norm(p=torch.inf)}")
if self.decision_rule == 'L1':
if (x - x_orig).norm(p=1).item() < best_dist and loss_model < best_loss:
best_loss = loss_model
best_attack = x.clone()
best_dist = (x - x_orig).norm(p=1).item()
else:
raise NotImplementedError
# Updating c
if overall_best_dist > best_dist and best_loss < overall_best_loss:
overall_best_loss = best_loss
overall_best_dist = best_dist
overall_best_attack = best_attack.clone()
c_upper = min(c_upper, c)
if c_upper < 1e9:
c = (c_upper + c_lower) / 2
else:
c_lower = max(c_lower, c)
if c_upper < 1e9:
c = (c_lower + c_upper) / 2.0
else:
c *= 10
print(f"Final L0: {(overall_best_attack - x_orig).norm(p=0)}, Linf: {(overall_best_attack - x_orig).norm(p=torch.inf)}")
return overall_best_attack.detach()
def project(self, x, x_orig):
mask_1 = (x - x_orig > self.beta).float()
mask_2 = ((x - x_orig).abs() <= self.beta).float()
mask_3 = (x - x_orig < -self.beta).float()
upper = torch.minimum(x - self.beta, torch.tensor(1.0))
lower = torch.maximum(x + self.beta, torch.tensor(0.0))
proj_x = mask_1 * upper + mask_2 * x_orig + mask_3 * lower
return proj_x
def loss_fn(self, x, c, x_orig):
out = -self.model(x).sum() if not self.targeted else self.model(x).sum()
l2_dist = ((x - x_orig) ** 2).view(x.shape[0], -1).sum(dim=1)
l1_dist = ((x - x_orig).abs()).view(x.shape[0], -1).sum(dim=1)
return out, c * out + l2_dist.sum() + \
self.beta * l1_dist.sum()