Spaces:
Runtime error
Runtime error
# Code taken and adapted from https://github.com/wagnermoritz/GSE
import bisect

import torch

from vlm_eval.attacks.attack import Attack
class SparseRS(Attack):
    '''
    L0 variant of the Sparse-RS random-search attack
    (https://arxiv.org/abs/2006.12834).

    Authors' implementation: https://github.com/fra31/sparse-rs
    Adapted from: https://github.com/wagnermoritz/GSE/tree/main

    The attack is query-only (no gradients): it keeps a set of k perturbed
    pixels per image, repeatedly swaps a shrinking number of them for new
    random pixels/colours, and keeps a candidate whenever the scalar
    objective returned by the model decreases.

    NOTE(review): the input is indexed as a 6-D tensor whose dims 1, 3, 4, 5
    are (images, channels, height, width); dims 0 and 2 must have size 1 for
    the internal reshape to succeed. Presumably this is the
    (batch, images, frames, C, H, W) layout of the wrapped VLM -- confirm
    against the caller.
    '''

    # Piecewise-constant alpha schedule: alpha_init is divided by
    # _ALPHA_FACTORS[j], where j is the number of breakpoints that the
    # (rescaled to a 10000-query budget) iteration counter has passed.
    _ALPHA_FACTORS = (1, 2, 4, 5, 6, 8, 10, 12, 15, 20)
    _ALPHA_BREAKPOINTS = (10, 50, 200, 500, 1000, 2000, 4000, 6000, 8000)

    def __init__(self, model, *args, targeted=False, img_range=(-1, 1),
                 n_queries=10000, k=100, n_restarts=10, alpha_init=0.8,
                 mask_out='none', **kwargs):
        '''
        args:
        model:         Callable whose output is summed to obtain the scalar
                       objective (see __lossfn).
        targeted:      Bool, if True the summed model output is minimised,
                       otherwise it is maximised.
        img_range:     Tuple of ints/floats, lower and upper bound of image
                       entries.
        n_queries:     Int, max number of queries to the model per restart.
        k:             Int, number of perturbed pixels per image.
        n_restarts:    Int, number of restarts with random initialization.
        alpha_init:    Float, initial value for the alpha schedule.
        mask_out:      'none' (perturb every image), 'context' (only the
                       last image along dim 1 may be perturbed), 'query'
                       (the last image along dim 1 is left untouched) or an
                       int index along dim 1 that is left untouched.
        Extra positional and keyword arguments are accepted and ignored.
        '''
        super().__init__(model, targeted=targeted, img_range=img_range)
        self.n_queries = n_queries
        self.k = k
        self.n_restarts = n_restarts
        self.alpha_init = alpha_init
        # The string 'none' is the "no masking" sentinel of the public API.
        self.mask_out = None if mask_out == 'none' else mask_out

    def _set_mask(self, data):
        '''
        Build a tensor shaped like data holding 1 where perturbation is
        allowed and 0 where it is forbidden, selected along dim 1 according
        to self.mask_out.

        Raises NotImplementedError for an unrecognised self.mask_out.
        '''
        mask = torch.ones_like(data)
        if self.mask_out is None:
            return mask
        if self.mask_out == 'context':
            mask[:, :-1, ...] = 0
        elif self.mask_out == 'query':
            mask[:, -1, ...] = 0
        elif isinstance(self.mask_out, int):
            mask[:, self.mask_out, ...] = 0
        else:
            raise NotImplementedError(f'Unknown mask_out: {self.mask_out}')
        return mask

    def __call__(self, x, *args, **kwargs):
        '''
        Run SparseRS L0 on the input x.

        args:
        x:         Tensor indexed as 6-D, see the class docstring for the
                   expected layout. Extra arguments are ignored.

        Returns a detached tensor of the same shape as x: the candidate with
        the lowest objective over all restarts, or an unmodified copy of x
        if no restart was run (n_restarts == 0 or empty x).

        Side effects: disables requires_grad on the parameters of
        self.model.model and re-seeds the global torch RNGs with 0 so that
        the random search is reproducible.
        '''
        for param in self.model.model.parameters():
            param.requires_grad = False

        torch.random.manual_seed(0)
        torch.cuda.random.manual_seed(0)
        x = x.to(self.device)

        # Fall back to the clean input so the return value is always defined.
        x_adv, best_loss = x.clone(), None
        with torch.no_grad():
            for _ in range(self.n_restarts):
                if len(x) == 0:
                    break
                candidate, loss = self.__perturb(x.clone())
                # Keep the best restart instead of blindly the last one.
                if best_loss is None or loss < best_loss:
                    x_adv, best_loss = candidate, loss
        return x_adv.detach()

    def __perturb(self, x):
        '''
        Perform the attack from a random starting point.

        Returns (x_adv, loss): the best candidate found, reshaped to
        x.shape, and its scalar objective value.

        Raises ValueError if self.k is not in (0, H * W).
        '''
        n_img, C, H, W = x.shape[1], x.shape[3], x.shape[4], x.shape[5]
        n_pix = H * W
        if not 0 < self.k < n_pix:
            raise ValueError(
                f'k must satisfy 0 < k < H * W = {n_pix}, got {self.k}')

        clean = x.reshape(n_img, C, H, W)
        # True where a value may be changed; masked-out images stay clean.
        editable = self._set_mask(x).reshape(n_img, C, H, W).bool()
        rows = torch.arange(n_img, device=self.device).view(-1, 1)

        # M: set of perturbed pixel indices, U_M: set of unperturbed pixel
        # indices (flat index p maps to row p // W and column p % W).
        order = torch.rand(n_img, n_pix, device=self.device).argsort(dim=1)
        M = order[:, :self.k]
        U_M = order[:, self.k:]

        result = clean.clone()
        result[rows, :, M // W, M % W] = self.__sampleDelta(
            n_img, C, self.k).to(result.dtype)
        result = torch.where(editable, result, clean)
        best_loss = self.__lossfn(result.reshape(x.shape))

        if n_img == 0:
            return result.reshape(x.shape), best_loss

        for i in range(1, self.n_queries):
            # reset k_i currently perturbed pixels and perturb k_i new pixels;
            # k_i is clamped so both index sets can supply that many entries
            k_i = max(int(self.__alphaSchedule(i) * self.k), 1)
            k_i = min(k_i, self.k, n_pix - self.k)
            A_idx = torch.randperm(self.k, device=self.device)[:k_i]
            B_idx = torch.randperm(n_pix - self.k, device=self.device)[:k_i]
            # Advanced indexing copies, so these stay valid across the swap.
            A_set, B_set = M[:, A_idx], U_M[:, B_idx]

            z = result.clone()
            z[rows, :, A_set // W, A_set % W] = \
                clean[rows, :, A_set // W, A_set % W]

            new_color = self.__sampleDelta(n_img, C, k_i).to(z.dtype)
            if k_i == 1:
                # if only one pixel is changed, make sure it actually changes
                while True:
                    same = (z[rows, :, B_set // W, B_set % W]
                            == new_color).reshape(n_img, -1).all(dim=-1)
                    if not same.any():
                        break
                    new_color[same] = self.__sampleDelta(
                        int(same.sum().item()), C, k_i).to(z.dtype)
            z[rows, :, B_set // W, B_set % W] = new_color
            z = torch.where(editable, z, clean)

            # The objective is one scalar for the whole input, so a candidate
            # is accepted or rejected for all images at once.
            loss = self.__lossfn(z.reshape(x.shape))
            if loss < best_loss:
                best_loss = loss
                result = z
                U_M[:, B_idx] = A_set
                M[:, A_idx] = B_set

        return result.reshape(x.shape), best_loss

    def __sampleDelta(self, B, C, k):
        '''
        Sample k-pixel perturbations for B images. Each pixel is assigned a
        random corner in the C-dimensional cube defined by self.img_range.

        Returns a float tensor of shape [B, k, C] on self.device.
        '''
        low = self.img_range[0]
        fac = self.img_range[1] - low
        # randint's upper bound is exclusive: 2 is needed to draw from {0, 1}.
        corners = torch.randint(0, 2, [B, k, C], dtype=torch.float,
                                device=self.device)
        return low + fac * corners

    def __alphaSchedule(self, iteration):
        '''
        Return the fraction of the k perturbed pixels to resample at the
        given iteration. The iteration is first rescaled to a budget of
        10000 queries, then mapped through the piecewise-constant schedule.
        '''
        scaled = int(iteration / self.n_queries * 10000)
        idx = bisect.bisect_left(self._ALPHA_BREAKPOINTS, scaled)
        return self.alpha_init / self._ALPHA_FACTORS[idx]

    def __lossfn(self, x):
        '''
        Return the scalar objective to minimise: the summed model output if
        self.targeted, otherwise its negation.
        '''
        total = self.model(x).sum()
        return total if self.targeted else -total