File size: 6,588 Bytes
fc0ff8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
# Code taken and adapted from https://github.com/wagnermoritz/GSE
import bisect

import torch

from vlm_eval.attacks.attack import Attack

class SparseRS(Attack):
    '''
    L0 variant of the Sparse-RS random-search attack
    (https://arxiv.org/abs/2006.12834).

    Authors' implementation: https://github.com/fra31/sparse-rs
    Adapted from: https://github.com/wagnermoritz/GSE/tree/main

    The attack keeps k perturbed pixel positions per image and repeatedly
    proposes swapping some of them for currently unperturbed positions. A
    proposal is kept only if it lowers the scalar loss returned by __lossfn.

    NOTE(review): self.model, self.device, self.targeted and self.img_range
    are read below but never assigned in this class; they are presumably set
    by Attack.__init__ - confirm against the base class.
    '''

    # Iteration breakpoints (normalised to a budget of 10000 queries) and the
    # divisor applied to alpha_init in each of the resulting intervals.
    _ALPHA_SCHEDULE = (10, 50, 200, 500, 1000, 2000, 4000, 6000, 8000)
    _ALPHA_FACTORS = (1, 2, 4, 5, 6, 8, 10, 12, 15, 20)

    def __init__(self, model, *args, targeted=False, img_range=(-1, 1),
                 n_queries=10000, k=100, n_restarts=10, alpha_init=0.8, mask_out='none',**kwargs):
        '''
        args:
        model:         Callable; __lossfn sums its output to obtain the loss.
        targeted:      Bool, the summed model output is minimised if True and
                       maximised otherwise.
        img_range:     Tuple of ints/floats, lower and upper bound of image
                       entries.
        n_queries:     Int, max number of queries to the model per restart.
        k:             Int, number of perturbed pixels per image.
        n_restarts:    Int, number of restarts with random initialization.
        alpha_init:    Float, initial value for the alpha schedule.
        mask_out:      'none', 'context', 'query' or an int; selects the
                       entries along dim 1 of the input that stay unperturbed
                       (see _set_mask).

        Extra positional and keyword arguments are accepted and ignored.
        '''
        super().__init__(model, targeted=targeted, img_range=img_range)
        self.n_queries = n_queries
        self.k = k
        self.n_restarts = n_restarts
        self.alpha_init = alpha_init
        # The string 'none' is the "no mask" sentinel; store it as None.
        self.mask_out = None if mask_out == 'none' else mask_out

    def _set_mask(self, data):
        '''
        Build a tensor shaped like data that is 1 where perturbations are
        allowed and 0 where the input must be left untouched.

        'context' zeroes every entry along dim 1 except the last one, 'query'
        zeroes only the last one, an int zeroes that single index, and None
        leaves the mask all ones.

        Raises NotImplementedError for any other value of self.mask_out.
        '''
        mask = torch.ones_like(data)
        if self.mask_out == 'context':
            mask[:, :-1, ...] = 0
        elif self.mask_out == 'query':
            mask[:, -1, ...] = 0
        elif isinstance(self.mask_out, int):
            mask[:, self.mask_out, ...] = 0
        elif self.mask_out is None:
            pass
        else:
            raise NotImplementedError(f'Unknown mask_out: {self.mask_out}')
        return mask

    def __call__(self, x, *args, **kwargs):
        '''
        Perform SparseRS L0 on the images in x.

        args:
        x:   6-D tensor. __perturb reads the image count from x.shape[1] and
             C, H, W from x.shape[3:6], and needs dims 0 and 2 to be of size 1
             (presumably [1, num_images, 1, C, H, W] - TODO confirm with the
             caller).

        Extra positional and keyword arguments are accepted and ignored.

        Returns a tensor of the same shape as x: the candidate with the lowest
        loss over all restarts, or x itself (moved to self.device) when x is
        empty or n_restarts < 1.
        '''
        # The search is gradient free, so gradients are switched off on the
        # wrapped model's parameters.
        for param in self.model.model.parameters():
            param.requires_grad = False

        # Fixed seeds make the random search reproducible across calls.
        torch.random.manual_seed(0)
        torch.cuda.random.manual_seed(0)
        x = x.to(self.device)

        if len(x) == 0 or self.n_restarts < 1:
            return x.detach()

        best_adv, best_loss = None, None
        with torch.no_grad():
            for _ in range(self.n_restarts):
                x_adv, loss = self.__perturb(x.clone())
                # Keep the best restart instead of blindly returning the last.
                if best_loss is None or loss < best_loss:
                    best_adv, best_loss = x_adv, loss

        return best_adv.detach()

    def __perturb(self, x):
        '''
        Run one random-search pass from a random starting point.

        Returns a tuple (x_adv, loss): x_adv has the same shape as x, with the
        positions zeroed by _set_mask restored to their clean values, and loss
        is the 0-dim loss tensor of x_adv.

        Raises ValueError if self.k is not in (0, H * W).
        '''
        B, C, H, W = x.shape[1], x.shape[3], x.shape[4], x.shape[5]
        n_pixels = H * W
        if not 0 < self.k < n_pixels:
            raise ValueError(
                f'k must satisfy 0 < k < H * W = {n_pixels}, got {self.k}')

        # True where a perturbation may be applied.
        keep = self._set_mask(x).bool().reshape(B, C, H, W)
        clean = x.reshape(B, C, H, W)
        batchidx = torch.arange(B, device=self.device).view(-1, 1)
        result = clean.clone()

        # M: set of perturbed pixel indices, U_M: set of unperturbed pixel
        # indices. Both hold flat indices row * W + col, so the row is
        # idx // W and the column is idx % W.
        batch_randperm = torch.rand(B, n_pixels, device=self.device).argsort(dim=1)
        M = batch_randperm[:, :self.k]
        U_M = batch_randperm[:, self.k:]
        result[batchidx, :, M // W, M % W] = self.__sampleDelta(B, C, self.k)

        best_loss = self.__query(result, clean, keep, x.shape)

        # One query was spent on the starting point, hence the range from 1.
        for i in range(1, self.n_queries):
            # reset k_i currently perturbed pixels and perturb k_i new pixels;
            # k_i cannot exceed the number of unperturbed pixels available.
            k_i = max(int(self.__alphaSchedule(i) * self.k), 1)
            k_i = min(k_i, n_pixels - self.k)
            A_idx = torch.randperm(self.k, device=self.device)[:k_i]
            B_idx = torch.randperm(n_pixels - self.k, device=self.device)[:k_i]
            A_set, B_set = M[:, A_idx], U_M[:, B_idx]

            z = result.clone()
            z[batchidx, :, A_set // W, A_set % W] = clean[batchidx, :, A_set // W, A_set % W]
            if k_i > 1:
                z[batchidx, :, B_set // W, B_set % W] = self.__sampleDelta(B, C, k_i)
            else: # if only one pixel is changed, make sure it actually changes
                new_color = self.__sampleDelta(B, C, k_i)
                while (same := (z[batchidx, :, B_set // W, B_set % W] == new_color).reshape(B, -1).all(dim=-1)).any():
                    new_color[same] = self.__sampleDelta(int(same.sum().item()), C, k_i)
                z[batchidx, :, B_set // W, B_set % W] = new_color

            # The loss is a single scalar for the whole input, so a proposal
            # is accepted or rejected for all images at once.
            loss = self.__query(z, clean, keep, x.shape)
            if loss < best_loss:
                best_loss = loss
                result = z
                # A_set / B_set are copies (advanced indexing), so the swap
                # below does not alias.
                U_M[:, B_idx] = A_set
                M[:, A_idx] = B_set

        return torch.where(keep, result, clean).reshape(x.shape), best_loss

    def __query(self, candidate, clean, keep, shape):
        '''
        Evaluate the loss of candidate after restoring the masked-out
        positions to their clean values and reshaping to the model's input
        shape.
        '''
        return self.__lossfn(torch.where(keep, candidate, clean).reshape(shape))

    def __sampleDelta(self, B, C, k):
        '''
        Sample k-pixel perturbations for B images. Each pixel is assigned a
        random corner in the C-dimensional cube defined by self.img_range.

        Returns a float tensor of shape [B, k, C].
        '''
        fac = self.img_range[1] - self.img_range[0]
        # The upper bound of randint is exclusive: 2 is needed to draw {0, 1}.
        return self.img_range[0] + fac * torch.randint(0, 2, [B, k, C],
                                                       dtype=torch.float,
                                                       device=self.device)

    def __alphaSchedule(self, iteration):
        '''
        Return the fraction of the k perturbed pixels to resample at the given
        iteration: alpha_init divided by a piecewise-constant factor that
        grows with the iteration, rescaled to a 10000-query budget.
        '''
        iteration = int(iteration / self.n_queries * 10000)
        idx = bisect.bisect_left(self._ALPHA_SCHEDULE, iteration)
        return self.alpha_init / self._ALPHA_FACTORS[idx]

    def __lossfn(self, x):
        '''
        Compute the loss depending on self.targeted: the summed model output
        if targeted, its negation otherwise. Lower is better for the attack.
        '''
        total = self.model(x).sum()
        return total if self.targeted else -total