# KC123hello's picture
# Upload Files
# fc0ff8f verified
# Code taken from https://github.com/chs20/RobustVLM/tree/main
import torch
from vlm_eval.attacks.utils import project_perturbation, normalize_grad
class PGD:
    """Iterative gradient attack with momentum that minimizes or maximizes a loss.

    Each iteration takes the gradient of ``forward(data_clean + perturbation)``
    with respect to the perturbation, masks it, normalizes it, folds it into a
    momentum buffer, takes a step, and projects the perturbation back via
    ``project_perturbation(perturbation, eps, norm)``.

    Args:
        forward: Callable mapping the perturbed input tensor to a loss tensor.
            The loss must be a scalar, since it is differentiated with
            ``torch.autograd.grad`` without ``grad_outputs``.
        norm: Norm identifier forwarded as ``p=`` to ``normalize_grad`` and as
            the third argument to ``project_perturbation``. Accepted values are
            defined by those helpers (not visible here).
        eps: Perturbation budget forwarded to ``project_perturbation``.
        mode: ``'min'`` to step against the gradient, ``'max'`` to step along it.
        mask_out: Which slices along dimension 1 receive no gradient:
            ``'context'`` zeroes every index except the last, ``'query'``
            zeroes the last index, an ``int`` zeroes that single index and
            ``None`` disables masking.
        image_space: If true, inputs are required to lie in ``[0, 1]`` and the
            perturbed data is clamped to that range after every step.
        momentum: Decay factor applied to the velocity buffer each iteration.
    """

    def __init__(self, forward, norm, eps, mode='min', mask_out='context',
                 image_space=True, momentum=0.9):
        self.model = forward
        self.norm = norm
        self.eps = eps
        self.momentum = momentum
        self.mode = mode
        self.mask_out = mask_out
        self.image_space = image_space

    def perturb(self, data_clean, iterations, stepsize, perturbation=None,
                verbose=False, return_loss=False):
        """Run the attack and return the perturbed data.

        Args:
            data_clean: Clean input tensor. Must have at least two dimensions
                when ``mask_out`` is not ``None`` (the mask indexes dim 1).
            iterations: Number of update steps. ``0`` returns the input plus
                the initial perturbation unchanged.
            stepsize: Step length multiplied onto the normalized velocity.
            perturbation: Optional starting perturbation; zeros if omitted.
                The caller's tensor is not modified.
            verbose: Print the loss roughly ten times and on the last step.
            return_loss: Also return the last computed loss.

        Returns:
            ``data_clean + perturbation`` (detached), or a tuple of that and
            the loss when ``return_loss`` is true. The loss is the one
            evaluated at the start of the final iteration, i.e. before the
            last update was applied; it is ``None`` when ``iterations`` is 0.

        Raises:
            ValueError: If ``self.mode`` is neither ``'min'`` nor ``'max'``.
            NotImplementedError: If ``self.mask_out`` is not recognised.
            AssertionError: If ``image_space`` is set and the data leaves
                ``[0, 1]`` (up to a 1e-6 tolerance).
        """
        if self.image_space:
            # make sure data is in image space
            assert torch.max(data_clean) < 1. + 1e-6 and torch.min(data_clean) > -1e-6, \
                'data_clean must lie in [0, 1] when image_space=True'

        if perturbation is None:
            perturbation = torch.zeros_like(data_clean, requires_grad=True)
        else:
            # Work on a detached alias so that requires_grad_() below does not
            # flip the flag on the tensor object owned by the caller.
            perturbation = perturbation.detach()

        mask = self._set_mask(data_clean)

        # Fail fast on a bad mode instead of after a full forward/backward pass.
        if self.mode not in ('min', 'max'):
            raise ValueError(f'Unknown mode: {self.mode}')
        descend = self.mode == 'min'

        velocity = torch.zeros_like(data_clean)
        # Keeps return_loss=True well-defined when the loop body never runs.
        loss = None

        for i in range(iterations):
            perturbation.requires_grad_()
            # enable_grad so the attack also works when called under no_grad
            with torch.enable_grad():
                loss = self.model(data_clean + perturbation)

            # print 10 times in total and last iteration
            if verbose and (i % (iterations // 10 + 1) == 0 or i == iterations - 1):
                print(f'[iteration] {i} [loss] {loss.item()}')

            with torch.no_grad():
                gradient = torch.autograd.grad(loss, perturbation)[0]
                gradient = mask * gradient
                if gradient.isnan().any():
                    print(f'attention: nan in gradient ({gradient.isnan().sum()})')
                    gradient[gradient.isnan()] = 0.

                # normalize
                gradient = normalize_grad(gradient, p=self.norm)

                # momentum
                velocity = self.momentum * velocity + gradient
                velocity = normalize_grad(velocity, p=self.norm)

                # update
                if descend:
                    perturbation = perturbation - stepsize * velocity
                else:
                    perturbation = perturbation + stepsize * velocity

                # project
                perturbation = project_perturbation(perturbation, self.eps, self.norm)
                if self.image_space:
                    # clamp to image space
                    perturbation = torch.clamp(
                        data_clean + perturbation, 0, 1
                    ) - data_clean
                    assert torch.max(data_clean + perturbation) < 1. + 1e-6 and torch.min(
                        data_clean + perturbation
                    ) > -1e-6
                assert not perturbation.isnan().any()

        # TODO: return the best perturbation seen; this needs the model to
        # output a per-sample (expanded) loss, which it currently does not.
        adversarial = data_clean + perturbation.detach()
        if return_loss:
            return adversarial, loss
        return adversarial

    def _set_mask(self, data):
        """Build the multiplicative gradient mask selected by ``self.mask_out``.

        Returns a tensor shaped like ``data`` that is 1 where gradients are
        kept and 0 where they are suppressed (see the class docstring for the
        meaning of each ``mask_out`` value).

        Raises:
            NotImplementedError: If ``self.mask_out`` is not ``'context'``,
                ``'query'``, an ``int`` or ``None``.
        """
        mask = torch.ones_like(data)
        if self.mask_out == 'context':
            # keep only the last index along dim 1
            mask[:, :-1, ...] = 0
        elif self.mask_out == 'query':
            # drop only the last index along dim 1
            mask[:, -1, ...] = 0
        elif isinstance(self.mask_out, int):
            mask[:, self.mask_out, ...] = 0
        elif self.mask_out is None:
            pass
        else:
            raise NotImplementedError(f'Unknown mask_out: {self.mask_out}')
        return mask