import torch
from torchvision.ops.boxes import batched_nms

from util.box_ops import box_cxcywh_to_xyxy
from .deformable_detr.deformable_transformer import DeformableTransformer


class OVTransformer(DeformableTransformer):
    def __init__(self, d_model=256, nhead=8,
                 num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=1024, dropout=0.1,
                 activation="relu", return_intermediate_dec=False,
                 num_feature_levels=4, dec_n_points=4, enc_n_points=4,
                 two_stage=False, two_stage_num_proposals=300,
                 assign_first_stage=False):
        super().__init__(d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, dropout,
                         activation, return_intermediate_dec, num_feature_levels, dec_n_points, enc_n_points,
                         two_stage, two_stage_num_proposals, assign_first_stage)

    def forward(self, srcs, masks, pos_embeds, query_embed=None, llm_feat=None, num_patch=1):
        assert self.two_stage or query_embed is not None
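        # Shape notes (inferred from usage below): srcs / masks / pos_embeds are
        # per-level lists of (bs, c, h, w) feature maps with matching padding masks
        # and positional encodings. llm_feat appears to carry one language-model
        # embedding per patch, roughly (bs, num_patch, d_model), and num_patch is
        # the number of prompt groups decoded jointly but kept independent by the
        # block-diagonal decoder mask built further down.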

        # prepare input for encoder
        src_flatten = []
        mask_flatten = []
        lvl_pos_embed_flatten = []
        spatial_shapes = []
        for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
            bs, c, h, w = src.shape
            spatial_shape = (h, w)
            spatial_shapes.append(spatial_shape)
            src = src.flatten(2).transpose(1, 2)
            mask = mask.flatten(1)
            pos_embed = pos_embed.flatten(2).transpose(1, 2)
            lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
            lvl_pos_embed_flatten.append(lvl_pos_embed)
            src_flatten.append(src)
            mask_flatten.append(mask)
        src_flatten = torch.cat(src_flatten, 1)
        mask_flatten = torch.cat(mask_flatten, 1)
        lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
        spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
        level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
        valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
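        # src_flatten is now (bs, sum_l h_l*w_l, c); level_start_index records the
        # offset at which each feature level starts in that flattened sequence.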

        # encoder
        memory = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios,
                              lvl_pos_embed_flatten, mask_flatten)
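        # memory keeps the flattened multi-level layout: (bs, sum_l h_l*w_l, c).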

        # prepare input for decoder
        bs, _, c = memory.shape
        if self.two_stage:
            output_memory, output_proposals, level_ids = \
                self.gen_encoder_output_proposals(memory, mask_flatten, spatial_shapes)

            # hack implementation for two-stage Deformable DETR
            enc_outputs_class = self.decoder.class_embed[self.decoder.num_layers](output_memory)
            enc_outputs_coord_unact = self.decoder.bbox_embed[self.decoder.num_layers](output_memory) + output_proposals

            topk = self.two_stage_num_proposals
            proposal_logit = enc_outputs_class[..., 0]
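            # Channel 0 of the class logits serves as the score for ranking
            # encoder proposals in both branches below.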

            if self.assign_first_stage:
                proposal_boxes = box_cxcywh_to_xyxy(enc_outputs_coord_unact.sigmoid().float()).clamp(0, 1)
                topk_proposals = []
                for b in range(bs):
                    prop_boxes_b = proposal_boxes[b]
                    prop_logits_b = proposal_logit[b]

                    # pre-nms per-level topk
                    pre_nms_topk = 1000
                    pre_nms_inds = []
                    for lvl in range(len(spatial_shapes)):
                        lvl_mask = level_ids == lvl
                        pre_nms_inds.append(torch.topk(prop_logits_b.sigmoid() * lvl_mask, pre_nms_topk)[1])
                    pre_nms_inds = torch.cat(pre_nms_inds)
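                    # Multiplying the sigmoid scores by the boolean level mask zeroes
                    # every other level, so each topk call effectively picks the best
                    # pre_nms_topk locations of a single level.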

                    # nms on topk indices
                    post_nms_inds = batched_nms(prop_boxes_b[pre_nms_inds],
                                                prop_logits_b[pre_nms_inds],
                                                level_ids[pre_nms_inds], 0.9)
                    keep_inds = pre_nms_inds[post_nms_inds]

                    if len(keep_inds) < self.two_stage_num_proposals:
                        print(f'[WARNING] nms proposals ({len(keep_inds)}) < {self.two_stage_num_proposals}')
                        keep_inds = torch.topk(proposal_logit[b], topk)[1]

                    # keep top Q/L indices for L levels
                    q_per_l = topk // len(spatial_shapes)
                    level_shapes = torch.arange(len(spatial_shapes), device=level_ids.device)[:, None]
                    is_level_ordered = level_ids[keep_inds][None] == level_shapes
                    keep_inds_mask = is_level_ordered & (is_level_ordered.cumsum(1) <= q_per_l)  # (L, S)
                    keep_inds_mask = keep_inds_mask.any(0)  # (S,)

                    # pad to Q indices (might let ones filtered from pre-nms sneak by...
                    # unlikely because we pick high conf anyways)
                    if keep_inds_mask.sum() < topk:
                        num_to_add = topk - keep_inds_mask.sum()
                        pad_inds = (~keep_inds_mask).nonzero()[:num_to_add]
                        keep_inds_mask[pad_inds] = True

                    # index
                    keep_inds_topk = keep_inds[keep_inds_mask]
                    topk_proposals.append(keep_inds_topk)
                topk_proposals = torch.stack(topk_proposals)
            else:
                topk_proposals = torch.topk(proposal_logit, topk, dim=1)[1]
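            # Either way, topk_proposals is (bs, topk) indices into the flattened
            # encoder output; the branch above additionally NMS-filters proposals
            # and balances them across feature levels per image.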

            topk_coords_unact = torch.gather(enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4))
            topk_coords_unact = topk_coords_unact.detach()
            reference_points = topk_coords_unact.sigmoid()
            init_reference_out = reference_points
            pos_trans_out = self.pos_trans_norm(self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact)))
            query_embed, tgt = torch.split(pos_trans_out, c, dim=2)
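            # pos_trans lifts the proposal coordinates to a 2c-wide embedding that is
            # split into the positional query (query_embed) and the content query (tgt).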
            num_queries = query_embed.shape[1]
            query_embed = query_embed.repeat(1, num_patch, 1)
            tgt = tgt.repeat(1, num_patch, 1)
            topk_feats = torch.stack([output_memory[b][topk_proposals[b]] for b in range(bs)]).detach()
            topk_feats = topk_feats.repeat(1, num_patch, 1)
            tgt = tgt + self.pix_trans_norm(self.pix_trans(topk_feats))
            reference_points = reference_points.repeat(1, num_patch, 1)
            init_reference_out = init_reference_out.repeat(1, num_patch, 1)
            llm_feat = llm_feat.repeat_interleave(num_queries, 1)
            tgt = tgt + llm_feat
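            # Layout check: repeat tiles the whole query block once per patch
            # ([q_0..q_{N-1}] * num_patch), while repeat_interleave expands each of
            # the num_patch language features num_queries times, so patch i's
            # feature lands exactly on patch i's block of queries.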
        else:
            # single-stage query path inherited from Deformable DETR; disabled here
            # (the code below the raise is kept for reference but is unreachable)
            raise NotImplementedError
            query_embed, tgt = torch.split(query_embed, c, dim=1)
            query_embed = query_embed.unsqueeze(0).expand(bs, -1, -1)
            tgt = tgt.unsqueeze(0).expand(bs, -1, -1)
            reference_points = self.reference_points(query_embed).sigmoid()
            init_reference_out = reference_points

        # decoder mask
        decoder_mask = (
            torch.ones(
                num_queries * num_patch,
                num_queries * num_patch,
                device=query_embed.device,
            ) * float("-inf")
        )
        for i in range(num_patch):
            decoder_mask[
                i * num_queries:(i + 1) * num_queries,
                i * num_queries:(i + 1) * num_queries,
            ] = 0
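        # Block-diagonal attention mask: queries of the same patch attend to each
        # other (0) while cross-patch attention is disabled (-inf), so the num_patch
        # prompt groups are decoded independently within one forward pass.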

        # decoder
        hs, inter_references = self.decoder(tgt, reference_points, memory,
                                            spatial_shapes, level_start_index, valid_ratios,
                                            query_embed, mask_flatten, tgt_mask=decoder_mask)
        inter_references_out = inter_references
        if self.two_stage:
            return (hs,
                    init_reference_out,
                    inter_references_out,
                    enc_outputs_class,
                    enc_outputs_coord_unact,
                    output_proposals.sigmoid())
        return hs, init_reference_out, inter_references_out, None, None, None


def build_ov_transformer(args):
    return OVTransformer(
        d_model=args.hidden_dim,
        nhead=args.nheads,
        num_encoder_layers=args.enc_layers,
        num_decoder_layers=args.dec_layers,
        dim_feedforward=args.dim_feedforward,
        dropout=args.dropout,
        activation="relu",
        return_intermediate_dec=True,
        num_feature_levels=args.num_feature_levels,
        dec_n_points=args.dec_n_points,
        enc_n_points=args.enc_n_points,
        two_stage=args.two_stage,
        two_stage_num_proposals=args.num_queries,
        assign_first_stage=args.assign_first_stage)
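

# A minimal construction sketch (not part of the original source). It assumes an
# argparse-style config whose attribute names mirror the reads in
# build_ov_transformer; the values below are illustrative defaults only. Since
# this module uses a relative import, run the snippet from a script that imports
# the package rather than by executing this file directly.
#
#     from types import SimpleNamespace
#
#     args = SimpleNamespace(
#         hidden_dim=256, nheads=8, enc_layers=6, dec_layers=6,
#         dim_feedforward=1024, dropout=0.1, num_feature_levels=4,
#         dec_n_points=4, enc_n_points=4, two_stage=True,
#         num_queries=300, assign_first_stage=True)
#     transformer = build_ov_transformer(args)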