File size: 4,755 Bytes
fc0ff8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
# Code taken and adapted from https://github.com/wagnermoritz/GSE
import torch
from vlm_eval.attacks.attack import Attack

class EAD(Attack):
    """Elastic-net Attack to DNNs (EAD).

    Minimizes ``c * model_loss + ||delta||_2^2 + beta * ||delta||_1`` with
    iterative shrinkage-thresholding (ISTA) plus FISTA-style momentum, while
    binary-searching the trade-off constant ``c``.

    Adapted from https://github.com/wagnermoritz/GSE.
    """

    def __init__(self, model, targeted=False, img_range=(0, 1), steps=100,
                 beta=5e-5, mask_out='none', ver=False, binary_steps=2,
                 step_size=1e-2, decision_rule='L1'):
        """
        Args:
            model: wrapped model; evaluated as ``model(x)`` and reduced with
                ``.sum()`` inside the loss.
            targeted: if True, maximize the model score instead of minimizing.
            img_range: valid pixel range used when clamping perturbed images.
            steps: ISTA iterations per binary-search step.
            beta: L1 regularization weight, also the shrinkage threshold.
            mask_out: 'none', 'context', 'query', or an int image index whose
                gradient is zeroed out (see ``_set_mask``).
            ver: if True, print progress every 20 iterations.
            binary_steps: number of binary-search updates of ``c``.
            step_size: initial gradient step size (decays with a sqrt schedule).
            decision_rule: metric used to keep the best attack; only 'L1' is
                implemented.
        """
        super().__init__(model=model, targeted=targeted, img_range=img_range)
        self.steps = steps
        self.ver = ver  # BUGFIX: was assigned twice; duplicate removed
        self.binary_steps = binary_steps
        self.beta = beta
        # 'none' is the CLI-style sentinel for "perturb every image".
        self.mask_out = mask_out if mask_out != 'none' else None
        self.decision_rule = decision_rule
        self.step_size = step_size

    def _set_mask(self, data):
        """Return a {0,1} mask zeroing gradients of masked-out images.

        Assumes ``data`` is laid out as (batch, n_images, ...) where index -1
        is the query image and preceding indices are context images — TODO
        confirm against callers.
        """
        mask = torch.ones_like(data)
        if self.mask_out == 'context':
            mask[:, :-1, ...] = 0
        elif self.mask_out == 'query':
            mask[:, -1, ...] = 0
        elif isinstance(self.mask_out, int):
            mask[:, self.mask_out, ...] = 0
        elif self.mask_out is None:
            pass
        else:
            raise NotImplementedError(f'Unknown mask_out: {self.mask_out}')
        return mask

    def __call__(self, x_orig):
        """Run the attack on ``x_orig`` and return the best adversarial input.

        Returns a detached tensor with the same shape as ``x_orig``.
        """
        # Freeze the model; only the input is optimized.
        for param in self.model.model.parameters():
            param.requires_grad = False

        mask_out = self._set_mask(x_orig)

        # Binary-search state for the trade-off constant c.
        c = 1e-1
        c_upper = 10e+10
        c_lower = 0

        overall_best_attack = x_orig.clone()
        overall_best_dist = torch.inf
        overall_best_loss = 1e10

        for binary_step in range(self.binary_steps):

            global_step = 0
            x = x_orig.clone().detach()
            y = x_orig.clone().detach()  # FISTA auxiliary point

            best_attack = x_orig.clone().detach()
            best_dist = torch.inf
            best_loss = 1e10

            # BUGFIX: honor the configured step size; previously the first
            # gradient step always used a hard-coded 1e-2 regardless of the
            # `step_size` constructor argument.
            step_size = self.step_size

            for step in range(self.steps):

                y.requires_grad = True
                _, loss = self.loss_fn(x=y, c=c, x_orig=x_orig)
                loss.backward()
                # Zero out gradients of images excluded by mask_out.
                y_grad = y.grad.data * mask_out

                with torch.no_grad():
                    # ISTA step: gradient descent followed by shrinkage.
                    x_new = self.project(x=y - step_size * y_grad, x_orig=x_orig)

                    # sqrt-decay schedule for the next iteration's step size.
                    step_size = self.step_size * (1 - global_step / self.steps) ** 0.5
                    global_step += 1

                    # FISTA momentum extrapolation.
                    y = x_new + (step / (step + 3)) * (x_new - x)
                    x = x_new

                    loss_model, loss = self.loss_fn(x=x, c=c, x_orig=x_orig)

                    if self.ver and step % 20 == 0:
                        print(f"Binary Step: {binary_step}, Iter: {step}, Loss: {loss.item()}, L0: {(x - x_orig).norm(p=0)}, Linf: {(x - x_orig).norm(p=torch.inf)}")

                    # Keep the iterate with the smallest L1 distortion among
                    # those that also improve the model loss.
                    if self.decision_rule == 'L1':
                        if (x - x_orig).norm(p=1).item() < best_dist and loss_model < best_loss:
                            best_loss = loss_model
                            best_attack = x.clone()
                            best_dist = (x - x_orig).norm(p=1).item()
                    else:
                        raise NotImplementedError

            # Updating c: shrink the interval on improvement, otherwise raise
            # the lower bound (or grow c tenfold while unbounded above).
            if overall_best_dist > best_dist and best_loss < overall_best_loss:
                overall_best_loss = best_loss
                overall_best_dist = best_dist
                overall_best_attack = best_attack.clone()

                c_upper = min(c_upper, c)
                if c_upper < 1e9:
                    c = (c_upper + c_lower) / 2

            else:
                c_lower = max(c_lower, c)
                if c_upper < 1e9:
                    c = (c_lower + c_upper) / 2.0
                else:
                    c *= 10

        print(f"Final L0: {(overall_best_attack - x_orig).norm(p=0)}, Linf: {(overall_best_attack - x_orig).norm(p=torch.inf)}")
        return overall_best_attack.detach()


    def project(self, x, x_orig):
        """Elementwise soft-thresholding (ISTA shrinkage) of ``x - x_orig``.

        Components within ``beta`` of the original are reset to it; larger
        components are shrunk by ``beta`` and clamped to the valid image range.
        """
        delta = x - x_orig
        mask_pos = (delta > self.beta).float()
        mask_keep = (delta.abs() <= self.beta).float()
        mask_neg = (delta < -self.beta).float()

        # BUGFIX: clamp to the configured image range instead of a hard-coded
        # (0, 1); falls back to (0, 1) if the base class does not expose
        # img_range. Tensor.clamp also avoids the device mismatch that
        # torch.minimum(x, torch.tensor(1.0)) causes for CUDA inputs.
        lo, hi = getattr(self, 'img_range', (0.0, 1.0))
        upper = (x - self.beta).clamp(max=hi)
        lower = (x + self.beta).clamp(min=lo)

        proj_x = mask_pos * upper + mask_keep * x_orig + mask_neg * lower
        return proj_x

    def loss_fn(self, x, c, x_orig):
        """Return ``(model_loss, total_objective)``.

        model_loss is negated for untargeted attacks (gradient descent then
        maximizes the model score's drop); the total objective adds the
        elastic-net penalty ``||delta||_2^2 + beta * ||delta||_1``.
        """
        out = -self.model(x).sum() if not self.targeted else self.model(x).sum()
        l2_dist = ((x - x_orig) ** 2).view(x.shape[0], -1).sum(dim=1)
        l1_dist = (x - x_orig).abs().view(x.shape[0], -1).sum(dim=1)

        return out, c * out + l2_dist.sum() + self.beta * l1_dist.sum()