# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Jeffrey Ouyang-Zhang
from typing import List

import torch
import torch.nn as nn

from util.box_ops import (box_cxcywh_to_xyxy, box_iou, box_xyxy_to_cxcywh,
                          generalized_box_iou)


# from https://github.com/facebookresearch/detectron2/blob/cbbc1ce26473cb2a5cc8f58e8ada9ae14cb41052/detectron2/layers/wrappers.py#L100
def nonzero_tuple(x):
    """
    An 'as_tuple=True' version of torch.nonzero to support torchscript.
    Needed because of https://github.com/pytorch/pytorch/issues/38718
    """
    if torch.jit.is_scripting():
        if x.dim() == 0:
            return x.unsqueeze(0).nonzero().unbind(1)
        return x.nonzero().unbind(1)
    else:
        return x.nonzero(as_tuple=True)
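
# Illustrative usage (a sketch, not part of the original file): nonzero_tuple behaves
# like torch.nonzero(x, as_tuple=True), returning one index tensor per dimension:
#   mask = torch.tensor([[True, False], [False, True]])
#   rows, cols = nonzero_tuple(mask)
#   # rows -> tensor([0, 1]), cols -> tensor([0, 1])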
# from https://github.com/facebookresearch/detectron2/blob/9921a2caa585d4fa66c4b534b6fab6e74d89b582/detectron2/modeling/matcher.py#L9
class Matcher(object):
    """
    This class assigns to each predicted "element" (e.g., a box) a ground-truth
    element. Each predicted element will have exactly zero or one matches; each
    ground-truth element may be matched to zero or more predicted elements.

    The matching is determined by the MxN match_quality_matrix, which characterizes
    how well each (ground-truth, prediction) pair matches. For example,
    if the elements are boxes, this matrix may contain box intersection-over-union
    overlap values.

    The matcher returns (a) a vector of length N containing the index of the
    ground-truth element m in [0, M) that matches prediction n in [0, N);
    (b) a vector of length N containing the labels for each prediction.
    """

    def __init__(
        self, thresholds: List[float], labels: List[int], allow_low_quality_matches: bool = False
    ):
        """
        Args:
            thresholds (list): a list of thresholds used to stratify predictions
                into levels.
            labels (list): a list of values to label predictions belonging to
                each level. A label can be one of {-1, 0, 1} signifying
                {ignore, negative class, positive class}, respectively.
            allow_low_quality_matches (bool): if True, produce additional matches
                for predictions with maximum match quality lower than high_threshold.
                See set_low_quality_matches_ for more details.

            For example,
                thresholds = [0.3, 0.5]
                labels = [0, -1, 1]
                All predictions with iou < 0.3 will be marked with 0 and
                thus will be considered as false positives while training.
                All predictions with 0.3 <= iou < 0.5 will be marked with -1 and
                thus will be ignored.
                All predictions with 0.5 <= iou will be marked with 1 and
                thus will be considered as true positives.
        """
        # Add -inf and +inf to first and last position in thresholds
        thresholds = thresholds[:]
        assert thresholds[0] > 0
        thresholds.insert(0, -float("inf"))
        thresholds.append(float("inf"))
        # Currently torchscript does not support all + generator
        assert all([low <= high for (low, high) in zip(thresholds[:-1], thresholds[1:])]), thresholds
        assert all([l in [-1, 0, 1] for l in labels])
        assert len(labels) == len(thresholds) - 1
        self.thresholds = thresholds
        self.labels = labels
        self.allow_low_quality_matches = allow_low_quality_matches

    def __call__(self, match_quality_matrix):
        """
        Args:
            match_quality_matrix (Tensor[float]): an MxN tensor, containing the
                pairwise quality between M ground-truth elements and N predicted
                elements. All elements must be >= 0 (due to the use of `torch.nonzero`
                for selecting indices in :meth:`set_low_quality_matches_`).

        Returns:
            matches (Tensor[int64]): a vector of length N, where matches[i] is a matched
                ground-truth index in [0, M)
            match_labels (Tensor[int8]): a vector of length N, where match_labels[i] indicates
                whether a prediction is a true or false positive or ignored
        """
        assert match_quality_matrix.dim() == 2
        if match_quality_matrix.numel() == 0:
            default_matches = match_quality_matrix.new_full(
                (match_quality_matrix.size(1),), 0, dtype=torch.int64
            )
            # When no gt boxes exist, we define IOU = 0 and therefore set labels
            # to `self.labels[0]`, which usually defaults to background class 0.
            # To ignore instead, one can use labels=[-1, 0, -1, 1] with appropriate thresholds.
            default_match_labels = match_quality_matrix.new_full(
                (match_quality_matrix.size(1),), self.labels[0], dtype=torch.int8
            )
            return default_matches, default_match_labels

        assert torch.all(match_quality_matrix >= 0)

        # match_quality_matrix is M (gt) x N (predicted)
        # Max over gt elements (dim 0) to find best gt candidate for each prediction
        matched_vals, matches = match_quality_matrix.max(dim=0)

        match_labels = matches.new_full(matches.size(), 1, dtype=torch.int8)

        for (l, low, high) in zip(self.labels, self.thresholds[:-1], self.thresholds[1:]):
            low_high = (matched_vals >= low) & (matched_vals < high)
            match_labels[low_high] = l

        if self.allow_low_quality_matches:
            self.set_low_quality_matches_(match_labels, match_quality_matrix)

        return matches, match_labels

    def set_low_quality_matches_(self, match_labels, match_quality_matrix):
        """
        Produce additional matches for predictions that have only low-quality matches.
        Specifically, for each ground-truth G find the set of predictions that have
        maximum overlap with it (including ties); for each prediction in that set, if
        it is unmatched, then match it to the ground-truth G.

        This function implements the RPN assignment case (i) in Sec. 3.1.2 of
        :paper:`Faster R-CNN`.
        """
        # For each gt, find the prediction with which it has highest quality
        highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1)
        # Find the highest quality match available, even if it is low, including ties.
        # Note that the match qualities must be positive due to the use of
        # `torch.nonzero`.
        _, pred_inds_with_highest_quality = nonzero_tuple(
            match_quality_matrix == highest_quality_foreach_gt[:, None]
        )
        # If an anchor was labeled positive only due to a low-quality match
        # with gt_A, but it has larger overlap with gt_B, its matched index will still be gt_B.
        # This follows the implementation in Detectron, and is found to have no significant impact.
        match_labels[pred_inds_with_highest_quality] = 1
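
# Illustrative usage (a sketch, not part of the original file), mirroring the example
# values in the __init__ docstring above:
#   matcher = Matcher(thresholds=[0.3, 0.5], labels=[0, -1, 1])
#   quality = torch.tensor([[0.2, 0.4, 0.8]])   # 1 ground-truth box vs. 3 predictions
#   matches, match_labels = matcher(quality)
#   # matches      -> tensor([0, 0, 0])                          (best gt for each prediction)
#   # match_labels -> tensor([0, -1, 1], dtype=torch.int8)       (negative, ignored, positive)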
# from https://github.com/facebookresearch/detectron2/blob/cbbc1ce26473cb2a5cc8f58e8ada9ae14cb41052/detectron2/modeling/sampling.py#L9
def subsample_labels(
    labels: torch.Tensor, num_samples: int, positive_fraction: float, bg_label: int
):
    """
    Return `num_samples` (or fewer, if not enough found)
    random samples from `labels` which is a mixture of positives & negatives.
    It will try to return as many positives as possible without
    exceeding `positive_fraction * num_samples`, and then try to
    fill the remaining slots with negatives.

    Args:
        labels (Tensor): (N, ) label vector with values:
            * -1: ignore
            * bg_label: background ("negative") class
            * otherwise: one or more foreground ("positive") classes
        num_samples (int): The total number of labels with value >= 0 to return.
            Values that are not sampled will be filled with -1 (ignore).
        positive_fraction (float): The number of subsampled labels with values > 0
            is `min(num_positives, int(positive_fraction * num_samples))`. The number
            of negatives sampled is `min(num_negatives, num_samples - num_positives_sampled)`.
            In other words, if there are not enough positives, the sample is filled with
            negatives. If there are also not enough negatives, then as many elements are
            sampled as is possible.
        bg_label (int): label index of background ("negative") class.

    Returns:
        pos_idx, neg_idx (Tensor):
            1D vector of indices. The total length of both is `num_samples` or fewer.
    """
    positive = nonzero_tuple((labels != -1) & (labels != bg_label))[0]
    negative = nonzero_tuple(labels == bg_label)[0]

    num_pos = int(num_samples * positive_fraction)
    # protect against not enough positive examples
    num_pos = min(positive.numel(), num_pos)
    num_neg = num_samples - num_pos
    # protect against not enough negative examples
    num_neg = min(negative.numel(), num_neg)

    # randomly select positive and negative examples
    perm1 = torch.randperm(positive.numel(), device=positive.device)[:num_pos]
    perm2 = torch.randperm(negative.numel(), device=negative.device)[:num_neg]

    pos_idx = positive[perm1]
    neg_idx = negative[perm2]
    return pos_idx, neg_idx
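
# Illustrative usage (a sketch, not part of the original file): sample at most 4 labels,
# at most 25% of them positive, with class 0 treated as background.
#   labels = torch.tensor([1, 0, 0, -1, 2, 0, 0, 0])
#   pos_idx, neg_idx = subsample_labels(labels, num_samples=4, positive_fraction=0.25, bg_label=0)
#   # pos_idx holds 1 index drawn from {0, 4} (the foreground labels 1 and 2);
#   # neg_idx holds 3 indices drawn from {1, 2, 5, 6, 7} (the background labels).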
def sample_topk_per_gt(pr_inds, gt_inds, iou, k):
    if len(gt_inds) == 0:
        return pr_inds, gt_inds

    # find the top-k matches for each gt
    gt_inds2, counts = gt_inds.unique(return_counts=True)
    scores, pr_inds2 = iou[gt_inds2].topk(k, dim=1)
    gt_inds2 = gt_inds2[:, None].repeat(1, k)

    # keep only as many matches as each gt originally had (at most k)
    pr_inds3 = torch.cat([pr[:c] for c, pr in zip(counts, pr_inds2)])
    gt_inds3 = torch.cat([gt[:c] for c, gt in zip(counts, gt_inds2)])
    return pr_inds3, gt_inds3
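
# Illustrative behavior (a sketch, not part of the original file): if gt 0 was originally
# matched to 6 proposals and gt 1 to 2 proposals, then with k=4 the function keeps
# min(6, 4) = 4 matches for gt 0 and min(2, 4) = 2 matches for gt 1, where the kept
# proposal indices are the top-k highest-IoU proposals for that gt taken from the full
# IoU matrix (not necessarily a subset of the originally matched proposals).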
# modified from https://github.com/facebookresearch/detectron2/blob/cbbc1ce26473cb2a5cc8f58e8ada9ae14cb41052/detectron2/modeling/roi_heads/roi_heads.py#L123
class Stage2Assigner(nn.Module):
    def __init__(self, num_queries, max_k=4):
        super().__init__()
        self.positive_fraction = 0.25
        self.bg_label = 400  # number > 91 to filter out later
        self.batch_size_per_image = num_queries
        self.proposal_matcher = Matcher(thresholds=[0.6], labels=[0, 1], allow_low_quality_matches=True)
        self.k = max_k

    def _sample_proposals(
        self, matched_idxs: torch.Tensor, matched_labels: torch.Tensor, gt_classes: torch.Tensor
    ):
        """
        Based on the matching between N proposals and M groundtruth,
        sample the proposals and set their classification labels.

        Args:
            matched_idxs (Tensor): a vector of length N, each is the best-matched
                gt index in [0, M) for each proposal.
            matched_labels (Tensor): a vector of length N, the matcher's label
                (one of cfg.MODEL.ROI_HEADS.IOU_LABELS) for each proposal.
            gt_classes (Tensor): a vector of length M.

        Returns:
            Tensor: a vector of indices of sampled proposals. Each is in [0, N).
            Tensor: a vector of the same length, the classification label for
                each sampled proposal. Each sample is labeled as either a category in
                [0, num_classes) or the background (num_classes).
        """
        has_gt = gt_classes.numel() > 0
        # Get the corresponding GT for each proposal
        if has_gt:
            gt_classes = gt_classes[matched_idxs]
            # Label unmatched proposals (0 label from matcher) as background (label=num_classes)
            gt_classes[matched_labels == 0] = self.bg_label
            # Label ignore proposals (-1 label)
            gt_classes[matched_labels == -1] = -1
        else:
            gt_classes = torch.zeros_like(matched_idxs) + self.bg_label

        sampled_fg_idxs, sampled_bg_idxs = subsample_labels(
            gt_classes, self.batch_size_per_image, self.positive_fraction, self.bg_label
        )

        sampled_idxs = torch.cat([sampled_fg_idxs, sampled_bg_idxs], dim=0)
        return sampled_idxs, gt_classes[sampled_idxs]

    def forward(self, outputs, targets, return_cost_matrix=False):
        # COCO categories are from 1 to 90. They set num_classes=91 and apply sigmoid.

        bs = len(targets)
        indices = []
        ious = []
        for b in range(bs):
            iou, _ = box_iou(
                box_cxcywh_to_xyxy(targets[b]['boxes']),
                box_cxcywh_to_xyxy(outputs['init_reference'][b].detach()),
            )
            # proposal_id -> highest_iou_gt_id, proposal_id -> [1 if iou > 0.6, 0 otherwise]
            matched_idxs, matched_labels = self.proposal_matcher(iou)
            # list of sampled proposal_ids, sampled_id -> [0, num_classes) + [bg_label]
            sampled_idxs, sampled_gt_classes = self._sample_proposals(
                matched_idxs, matched_labels, targets[b]['labels']
            )
            pos_pr_inds = sampled_idxs[sampled_gt_classes != self.bg_label]
            pos_gt_inds = matched_idxs[pos_pr_inds]
            pos_pr_inds, pos_gt_inds = self.postprocess_indices(pos_pr_inds, pos_gt_inds, iou)
            indices.append((pos_pr_inds, pos_gt_inds))
            ious.append(iou)
        if return_cost_matrix:
            return indices, ious
        return indices

    def postprocess_indices(self, pr_inds, gt_inds, iou):
        return sample_topk_per_gt(pr_inds, gt_inds, iou, self.k)
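
# Illustrative usage (a sketch, not part of the original file; it only assumes the keys
# that forward() reads above):
#   assigner = Stage2Assigner(num_queries=300)
#   # outputs['init_reference'][b]: (num_queries, 4) reference boxes in cxcywh format
#   # targets[b]['boxes']: (M, 4) cxcywh gt boxes; targets[b]['labels']: (M,) class ids
#   indices = assigner(outputs, targets)
#   # indices[b] = (pos_pr_inds, pos_gt_inds): query indices matched to gt indices,
#   # with at most max_k queries kept per ground-truth box.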
# modified from https://github.com/facebookresearch/detectron2/blob/cbbc1ce26473cb2a5cc8f58e8ada9ae14cb41052/detectron2/modeling/proposal_generator/rpn.py#L181
class Stage1Assigner(nn.Module):
    def __init__(self, t_low=0.3, t_high=0.7, max_k=4):
        super().__init__()
        self.positive_fraction = 0.5
        self.batch_size_per_image = 256
        self.k = max_k
        self.t_low = t_low
        self.t_high = t_high
        self.anchor_matcher = Matcher(thresholds=[t_low, t_high], labels=[0, -1, 1], allow_low_quality_matches=True)

    def _subsample_labels(self, label):
        """
        Randomly sample a subset of positive and negative examples, and overwrite
        the label vector to the ignore value (-1) for all elements that are not
        included in the sample.

        Args:
            labels (Tensor): a vector of -1, 0, 1. Will be modified in-place and returned.
        """
        pos_idx, neg_idx = subsample_labels(
            label, self.batch_size_per_image, self.positive_fraction, 0
        )
        # Fill with the ignore label (-1), then set positive and negative labels
        label.fill_(-1)
        label.scatter_(0, pos_idx, 1)
        label.scatter_(0, neg_idx, 0)
        return label

    def forward(self, outputs, targets):
        bs = len(targets)
        indices = []
        for b in range(bs):
            anchors = outputs['anchors'][b]
            if len(targets[b]['boxes']) == 0:
                indices.append((torch.tensor([], dtype=torch.long, device=anchors.device),
                                torch.tensor([], dtype=torch.long, device=anchors.device)))
                continue
            iou, _ = box_iou(
                box_cxcywh_to_xyxy(targets[b]['boxes']),
                box_cxcywh_to_xyxy(anchors),
            )
            # anchor_id -> highest_iou_gt_id, anchor_id -> [1 if iou > 0.7, 0 if iou < 0.3, -1 otherwise]
            matched_idxs, matched_labels = self.anchor_matcher(iou)
            matched_labels = self._subsample_labels(matched_labels)

            all_pr_inds = torch.arange(len(anchors))
            pos_pr_inds = all_pr_inds[matched_labels == 1]
            pos_gt_inds = matched_idxs[pos_pr_inds]
            pos_ious = iou[pos_gt_inds, pos_pr_inds]
            pos_pr_inds, pos_gt_inds = self.postprocess_indices(pos_pr_inds, pos_gt_inds, iou)
            pos_pr_inds, pos_gt_inds = pos_pr_inds.to(anchors.device), pos_gt_inds.to(anchors.device)
            indices.append((pos_pr_inds, pos_gt_inds))
        return indices

    def postprocess_indices(self, pr_inds, gt_inds, iou):
        return sample_topk_per_gt(pr_inds, gt_inds, iou, self.k)
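
# Illustrative usage (a sketch, not part of the original file; it only assumes the keys
# that forward() reads above):
#   assigner = Stage1Assigner(t_low=0.3, t_high=0.7, max_k=4)
#   # outputs['anchors'][b]: (N, 4) anchor boxes in cxcywh format
#   # targets[b]['boxes']: (M, 4) cxcywh gt boxes
#   indices = assigner(outputs, targets)
#   # indices[b] = (pos_anchor_inds, pos_gt_inds): anchors labeled positive after IoU
#   # thresholding and subsampling, with at most max_k anchors kept per ground-truth box.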