File size: 5,002 Bytes
fc0ff8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
# Code adapted from https://github.com/wagnermoritz/GSE

from vlm_eval.attacks.attack import Attack
import torch
import math
import time

class SAIF(Attack):
    '''
    Sparse Frank-Wolfe attack SAIF (https://arxiv.org/pdf/2212.07495.pdf).

    Adapted from: https://github.com/wagnermoritz/GSE/tree/main

    The adversarial input is built as x + s * p * mask, where p is a
    perturbation whose Frank-Wolfe vertices are +-eps (inf-norm LMO), s is a
    mask whose Frank-Wolfe vertices have at most k ones per image (1-norm
    LMO), and mask zeroes out the images selected via `mask_out`.
    '''

    def __init__(self, model, *args, targeted=False, img_range=(-1, 1), steps=200,
                 r0=1, ver=False, k=10000, eps=16./255., mask_out='none', **kwargs):
        '''
        args:
        model:         Callable. `model(x)` is summed and used as the attack
                       objective; `model.model.parameters()` must exist
                       because the parameters are frozen in `__call__`.
        targeted:      Bool, minimise the summed model output if True,
                       maximise it otherwise.
        img_range:     Tuple of ints/floats, lower and upper bound of image
                       entries (used to clamp x + p).
        steps:         Int, number of FW iterations.
        r0:            Int, initial exponent r of the step size
                       mu = 1 / (2 ** r * sqrt(t + 1)).
        ver:           Bool, print progress if True.
        k:             Int, sparsity parameter: number of entries per image
                       selected by the 1-norm LMO.
        eps:           Float, perturbation magnitude parameter.
        mask_out:      'none' (attack every image), 'context' (attack only
                       the last image along dim 1), 'query' (attack all but
                       the last image along dim 1) or an int index along
                       dim 1 of the image that must stay untouched.

        Extra positional and keyword arguments are accepted and ignored.
        '''
        super().__init__(model, targeted=targeted, img_range=img_range)
        self.steps = steps
        self.r0 = r0
        # Kept for backward compatibility; not used by any method of this class.
        self.loss_fn = torch.nn.CrossEntropyLoss()
        self.ver = ver
        self.k = k
        self.eps = eps
        # The string 'none' is the "no masking" sentinel; store it as None.
        self.mask_out = mask_out if mask_out != 'none' else None

    def _set_mask(self, data):
        '''
        Build a 0/1 tensor shaped like `data` that is 0 for every image
        (indexed along dim 1) that must not be perturbed.

        Raises NotImplementedError for an unsupported `mask_out` value.
        '''
        mask = torch.ones_like(data)
        if self.mask_out == 'context':
            # Protect every image except the last one.
            mask[:, :-1, ...] = 0
        elif self.mask_out == 'query':
            # Protect only the last image.
            mask[:, -1, ...] = 0
        elif isinstance(self.mask_out, int):
            mask[:, self.mask_out, ...] = 0
        elif self.mask_out is not None:
            raise NotImplementedError(f'Unknown mask_out: {self.mask_out}')
        return mask

    def _objective(self, inp):
        '''
        Signed objective that the attack minimises: the summed model output
        for targeted attacks, its negation otherwise.
        '''
        out = self.model(inp).sum()
        return out if self.targeted else -out

    def _topk_mask(self, grad, n_imgs):
        '''
        Return a float 0/1 tensor shaped like `grad` marking, for each of the
        `n_imgs` rows of the flattened gradient, the (at most) k entries with
        the smallest gradient values.
        '''
        flat = grad.reshape(n_imgs, -1)
        # torch.topk raises if k exceeds the row length, so clamp it.
        k = min(self.k, flat.shape[1])
        smallest = torch.topk(-flat, k, dim=1)[1]
        selected = torch.zeros(flat.shape, device=flat.device)
        selected.scatter_(1, smallest, 1)
        return selected.view(grad.shape)

    def __call__(self, x):
        '''
        Perform the attack on a batch of images x.

        args:
        x:   Tensor whose first dimension has size 1 and whose second
             dimension indexes the images; the remaining dimensions of each
             image are flattened for the top-k selection
             (e.g. [1, N, F, C, H, W] - NOTE(review): layout inferred from
             the indexing of the original implementation, confirm with the
             caller).

        Returns a tuple (x_adv, l0) where x_adv is a tensor of the same shape
        as x containing the adversarial examples and l0 is the number of
        non-zero entries of s * p as a Python number.

        Side effect: `requires_grad` is set to False on every parameter of
        `self.model.model` and is not restored afterwards.
        '''
        assert x.shape[0] == 1, "Only support batch size 1 for now"

        for param in self.model.model.parameters():
            param.requires_grad = False

        n_imgs = x.shape[1]
        # self.device is not defined in this class; presumably provided by
        # the Attack base class.
        x = x.to(self.device)
        mask_out = self._set_mask(x)

        # compute p_0 and s_0 from the gradient at the clean input
        x_ = x.clone()
        x_.requires_grad = True
        loss = self._objective(x_)
        grad0 = torch.autograd.grad(loss, [x_])[0].detach() * mask_out
        p = (-self.eps * grad0.sign()).detach().half()
        s = torch.logical_and(self._topk_mask(grad0, n_imgs), grad0 < 0).float()
        s = s.detach().half()

        # r is shared across iterations: once the step size had to shrink it
        # never grows back.
        r = self.r0

        for t in range(self.steps):
            if self.ver:
                print(f'\r Iteration {t+1}/{self.steps}', end='')
            # p and s are fresh leaf tensors each iteration (rebuilt under
            # no_grad below), so their .grad never accumulates.
            p.requires_grad = True
            s.requires_grad = True

            D = self.Loss_fn(x, s, p, mask_out)
            D.backward()

            mp = p.grad * mask_out
            ms = s.grad * mask_out
            with torch.no_grad():
                # inf-norm LMO
                v = (-self.eps * mp.sign()).half()

                # 1-norm LMO
                ksmask = self._topk_mask(ms, n_imgs) * mask_out
                z = torch.logical_and(ksmask, ms < 0).float().half()

                # update stepsize until primal progress is made
                mu = 1 / (2 ** r * math.sqrt(t + 1))
                while self.Loss_fn(x, s + mu * (z - s), p + mu * (v - p), mask_out) > D:
                    r += 1
                    if r >= 50:
                        # Give up shrinking; keep the last computed mu.
                        break
                    mu = 1 / (2 ** r * math.sqrt(t + 1))

                p = p + mu * (v - p)
                s = s + mu * (z - s)

                # Project p so that x + p stays inside the valid image range.
                x_adv = torch.clamp(x + p, *self.img_range)
                p = x_adv - x

            if self.ver and t % 10 == 0:
                print(f" Loss: {D}")
        if self.ver:
            print('')
        return (x + s * p * mask_out).detach(), torch.norm(s*p,p=0).item()

    def Loss_fn(self, x, s, p, mask_out):
        '''
        Signed objective evaluated at the perturbed input x + s * p * mask_out
        (summed model output if targeted, its negation otherwise).
        '''
        return self._objective(x + s * p * mask_out)