File size: 3,233 Bytes
fc0ff8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# Code taken and adapted from https://github.com/wagnermoritz/GSE

import torch
from vlm_eval.attacks.attack import Attack
import math

class IHT(Attack):

    def __init__(self, model, targeted=False, img_range=(0, 1), steps=100, prox='hard',ver=False, lam=5e-5, mask_out='none',stepsize=0.015,eps=4./255.):
        super().__init__(model, targeted=targeted, img_range=img_range)
        self.steps = steps
        self.stepsize = stepsize
        self.ver = ver
        self.lam = lam
        self.eps = eps
        if mask_out != 'none':
            self.mask_out = mask_out
        else:
            self.mask_out = None
        if prox == 'hard':
            self.Prox = self.hardprox
        else:
            raise NotImplementedError



    def _set_mask(self, data):
        mask = torch.ones_like(data)
        if self.mask_out == 'context':
            mask[:, :-1, ...] = 0
        elif self.mask_out == 'query':
            mask[:, -1, ...] = 0
        elif isinstance(self.mask_out, int):
            mask[:, self.mask_out, ...] = 0
        elif self.mask_out is None:
            pass
        else:
            raise NotImplementedError(f'Unknown mask_out: {self.mask_out}')
        return mask
    
    def __call__(self, img):
        
        for param in self.model.model.parameters():
            param.requires_grad = False

        img = img.to(self.device)
        mask_out = self._set_mask(img)       
        x = torch.zeros_like(img) # perturbation to optimize
        z = x.clone() # used for FISTA extrapolation
        t = 1
        if self.ver:
            print('')

        for i in range(self.steps):
                # compue gradient
            x.requires_grad = True
            loss = self.model(img + x).sum() if self.targeted else -self.model(img + x).sum()
            loss.backward()
            x_grad = x.grad.data * mask_out
            x = x.detach()

            if self.ver and i % 20 == 0:
                print(f'Iteration: {i+1}, Loss: {loss}\n', end='')

                # FISTA update
            with torch.no_grad():
                t_ = .5 * (1 + math.sqrt(1 + 4 * t ** 2))
                alpha = (t - 1) / t_
                t = t_
                z_ = self.Prox(x=x - self.stepsize * x_grad, 
                               lam=self.lam * self.stepsize, 
                               img=img, 
                               eps=self.eps
                )
                x = z_ + alpha * (z_ - z)
                x = torch.clamp(x,-self.eps,self.eps)
                z = z_.clone()
                x = torch.clamp(img + x, *self.img_range) - img

        if self.ver:
            print('')
        print(f"L0 pert norm: {x.norm(p=0)}")
        
        return (img + x * mask_out).detach(), x.norm(p=0).item()

    def hardprox(self, x, lam, img, eps):
        '''
        Computes the hard thresholding proximal operator of the the
        perturbation x.

        :x:   Perturbation after gradient descent step.
        :lam: Regularization parameter.
        '''
        x_proj = torch.clamp(x,-eps,eps)
        x_temp = torch.clamp(img + x_proj,*self.img_range)
        x_proj = x_temp - img
        return torch.where(x ** 2 - (x_proj - x) ** 2 > 2 * lam, x_proj, 0)