import math
import random

import torch
import torch.nn as nn
import torch.nn.functional as F

from util.misc import (NestedTensor, inverse_sigmoid,
                       nested_tensor_from_tensor_list)

from .blip2_decoder import BLIP2Decoder
from .deformable_detr.backbone import build_backbone
from .deformable_detr.deformable_detr import DeformableDETR
from .transformer import build_ov_transformer


class ContextDET(DeformableDETR):
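    """A Deformable-DETR detector conditioned on LLM hidden states.

    A BLIP-2/OPT decoder (``llm_decoder``) produces token-level hidden states
    for the input text; features pooled over object-name spans of those states
    are fed to the detection transformer as extra conditioning, so boxes can be
    predicted for the names mentioned (or masked out) in the text.
    """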

    def __init__(self, backbone, transformer, num_classes, num_queries, num_feature_levels,
                 aux_loss=True, with_box_refine=False, two_stage=False, llm_decoder=None):
        super().__init__(backbone, transformer, num_classes, num_queries, num_feature_levels,
                         aux_loss, with_box_refine, two_stage)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.llm_decoder = llm_decoder
        hidden_dim = transformer.d_model
        # Project LLM hidden states into the detector's embedding space.
        out_size = self.llm_decoder.model.opt_proj.out_features
        self.llm_proj = nn.Linear(out_size, hidden_dim, device=self.device)
        # Per-token scores marking the start and end of object-name spans.
        self.start_end_proj = nn.Linear(hidden_dim, 2)
        for layer in [self.llm_proj, self.start_end_proj]:
            nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')
            nn.init.zeros_(layer.bias)
        # word_embed_proj_dim = llm_decoder.model.opt_model.config.word_embed_proj_dim
        # MLM head over the OPT vocabulary, used for the cloze-test task.
        vocab_size = llm_decoder.model.opt_model.config.vocab_size
        self.fc_logits = nn.Linear(hidden_dim, vocab_size)

    def forward(self, samples, blip2_samples, mask_infos=None, task_button=None, threshold=0.3):
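        """Run the LLM decoder, then condition the detection transformer on
        pooled span features.

        Args:
            samples: images as a ``NestedTensor`` (or a list of tensors).
            blip2_samples: inputs for the BLIP-2/OPT decoder.
            mask_infos: per-image dicts mapping ``(start, end)`` token spans to
                names; required at training time, optional at inference.
            task_button: task selector forwarded to the LLM (e.g. 'Cloze Test').
            threshold: score threshold for predicting spans at inference.
        """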
        logits, hidden_states, input_ids, output_text = self.llm_decoder.model.forward(
            blip2_samples, task_button=task_button)
        # The LLM stays frozen here: detach so no gradients flow back into it.
        hidden_states = hidden_states.detach()
        hidden_states = self.llm_proj(hidden_states)
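        # Standard Deformable-DETR feature extraction: backbone pyramid plus
        # per-level input projections and positional encodings.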
        if not isinstance(samples, NestedTensor):
            samples = nested_tensor_from_tensor_list(samples)
        features, pos = self.backbone(samples)
        srcs = []
        masks = []
        for l, feat in enumerate(features):
            src, mask = feat.decompose()
            srcs.append(self.input_proj[l](src))
            masks.append(mask)
            assert mask is not None
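        # If the transformer expects more feature levels than the backbone
        # returns, synthesize the extra levels from the last feature map (and
        # from each newly created map) via the remaining input projections.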
        if self.num_feature_levels > len(srcs):
            _len_srcs = len(srcs)
            for l in range(_len_srcs, self.num_feature_levels):
                if l == _len_srcs:
                    src = self.input_proj[l](features[-1].tensors)
                else:
                    src = self.input_proj[l](srcs[-1])
                m = samples.mask
                mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0]
                pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
                srcs.append(src)
                masks.append(mask)
                pos.append(pos_l)
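        # Language-side heads: MLM logits over the vocabulary plus per-token
        # start/end scores for locating object-name spans in the text.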
        out = {}
        start_end_proj = self.start_end_proj(hidden_states)
        out['pred_mlm_logits'] = self.fc_logits(hidden_states)
        out['pred_start'] = start_end_proj[:, :, 0:1]
        out['pred_end'] = start_end_proj[:, :, 1:2]
        out['output_text'] = output_text
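        # Training: ground-truth spans come from mask_infos; sample at most two
        # spans per image so every image conditions on the same number k.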
        if self.training:
            k = min([len(mask_info) for mask_info in mask_infos])
            k = min(k, 2)
            # random.sample needs a sequence, so materialize the dict keys first.
            select_ids = [random.sample(list(mask_info.keys()), k) for mask_info in mask_infos]
            # select_ids = [random.choices(list(mask_info.keys()), k=4) for mask_info in mask_infos]
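            # Mean-pool the LLM hidden states over each selected span to build
            # the conditioning features handed to the detection transformer.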
            llm_feat = []
            for b in range(len(select_ids)):
                llm_feat_b = []
                hidden_states_b = hidden_states[b, :, :]
                for start, end in select_ids[b]:
                    llm_feat_b.append(hidden_states_b[start:end + 1].mean(dim=0, keepdim=True))
                llm_feat.append(torch.cat(llm_feat_b)[None])
            llm_feat = torch.cat(llm_feat)
            query_embeds = None
            if not self.two_stage:
                query_embeds = self.query_embed.weight
            hs, init_reference, inter_references, enc_outputs_class, enc_outputs_coord_unact, anchors = (
                self.transformer(srcs, masks, pos, query_embeds, llm_feat, k)
            )
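            # Apply the classification and box heads at every decoder layer;
            # boxes are refined layer by layer relative to the references.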
            outputs_classes = []
            outputs_coords = []
            for lvl in range(hs.shape[0]):
                if lvl == 0:
                    reference = init_reference
                else:
                    reference = inter_references[lvl - 1]
                reference = inverse_sigmoid(reference)
                outputs_class = self.class_embed[lvl](hs[lvl])
                tmp = self.bbox_embed[lvl](hs[lvl])
                if reference.shape[-1] == 4:
                    tmp += reference
                else:
                    assert reference.shape[-1] == 2
                    tmp[..., :2] += reference
                outputs_coord = tmp.sigmoid()
                outputs_classes.append(outputs_class)
                outputs_coords.append(outputs_coord)
            outputs_class = torch.stack(outputs_classes)
            outputs_coord = torch.stack(outputs_coords)
            out.update({'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1],
                        'init_reference': init_reference})
            out['select_ids'] = select_ids
            if self.aux_loss:
                out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)
                for temp in out['aux_outputs']:
                    temp['select_ids'] = select_ids
            if self.two_stage:
                enc_outputs_coord = enc_outputs_coord_unact.sigmoid()
                out['enc_outputs'] = {
                    'pred_logits': enc_outputs_class,
                    'pred_boxes': enc_outputs_coord,
                    'anchors': anchors,
                }
        else:
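            # Inference: spans are either supplied via mask_infos or predicted
            # below from the MLM and start/end heads.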
            bs = len(samples.tensors)
            mask_infos_pred = [{} for _ in range(bs)]
            llm_feat = []
            tokenizer = self.llm_decoder.model.opt_tokenizer
            if mask_infos is None:
                if task_button == 'Cloze Test':
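                    # Cloze test: substitute each mask token with the argmax
                    # MLM prediction and record a single-token span for it.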
                    mask_infos = []
                    output_texts = []
                    for b in range(bs):
                        mask_infos_b = {}
                        output_texts_b = []
                        for ind, token in enumerate(input_ids[b]):
                            if token == tokenizer.mask_token_id:
                                mask_infos_b[(ind, ind)] = ''
                                pred_token = out['pred_mlm_logits'][b, ind:ind + 1, :]
                                pred_token = pred_token.argmax(1).item()
                                output_texts_b.append(pred_token)
                                # 1437 is presumably a space token in this BPE
                                # vocabulary, keeping the decoded text readable.
                                output_texts_b.append(1437)
                                input_ids[b, ind:ind + 1] = pred_token
                            else:
                                output_texts_b.append(token.item())
                        mask_infos.append(mask_infos_b)
                        # Drop the first token (BOS) before decoding.
                        output_texts.append(tokenizer.decode(output_texts_b[1:]))
                    out['output_text'] = output_texts
                else:
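                    # Otherwise predict spans by thresholding the sigmoid
                    # start/end scores; fall back to the argmax position when
                    # nothing clears the threshold.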
                    mask_infos = []
                    for b in range(bs):
                        starts = (out['pred_start'][b, :, 0].sigmoid() > threshold).nonzero().squeeze(1)
                        ends = (out['pred_end'][b, :, 0].sigmoid() > threshold).nonzero().squeeze(1)
                        if len(starts) == 0:
                            starts = out['pred_start'][b, :].argmax(0)
                        if len(ends) == 0:
                            ends = out['pred_end'][b, :].argmax(0)
                        mask_infos_b = {}
                        for start, end in zip(starts, ends):
                            mask_infos_b[(int(start), int(end))] = ''
                        mask_infos.append(mask_infos_b)
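            # Mean-pool hidden states over every span and decode each span's
            # tokens as the predicted object name.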
            for b in range(bs):
                llm_feat_b = []
                hidden_states_b = hidden_states[b, :, :]
                for start, end in mask_infos[b].keys():
                    llm_feat_b.append(hidden_states_b[start:end + 1].mean(dim=0, keepdim=True))
                    pred_name = tokenizer.decode(input_ids[b, start:end + 1]).strip()
                    mask_infos_pred[b][(int(start), int(end))] = pred_name
                llm_feat.append(torch.cat(llm_feat_b)[None])
            out['mask_infos_pred'] = mask_infos_pred
            query_embeds = None
            if not self.two_stage:
                query_embeds = self.query_embed.weight
            outputs_classes_list = []
            outputs_coords_list = []
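            # Decode image by image since the span count k varies; spans go
            # through the transformer in chunks of at most 4, presumably to
            # bound the number of conditioning features per pass.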
            for b in range(bs):
                srcs_b = [i[b:b + 1] for i in srcs]
                masks_b = [i[b:b + 1] for i in masks]
                pos_b = [i[b:b + 1] for i in pos]
                k = len(mask_infos[b])
                if k == 0:
                    outputs_classes_list.append(torch.zeros(0, 2).to(self.device))
                    outputs_coords_list.append(torch.zeros(0, 4).to(self.device))
                    continue
                num_repeat = math.ceil(k / 4)
                outputs_classes = []
                outputs_coords = []
                for ind in range(num_repeat):
                    llm_feat_b = llm_feat[b][:, ind * 4:(ind + 1) * 4]
                    hs, init_reference, inter_references, enc_outputs_class, enc_outputs_coord_unact, anchors = (
                        self.transformer(srcs_b, masks_b, pos_b, query_embeds, llm_feat_b, llm_feat_b.shape[1])
                    )
                    # Only the last decoder layer's outputs are used at inference.
                    lvl = hs.shape[0] - 1
                    reference = inter_references[lvl - 1]
                    reference = inverse_sigmoid(reference)
                    outputs_class = self.class_embed[lvl](hs[lvl])
                    tmp = self.bbox_embed[lvl](hs[lvl])
                    if reference.shape[-1] == 4:
                        tmp += reference
                    else:
                        assert reference.shape[-1] == 2
                        tmp[..., :2] += reference
                    outputs_coord = tmp.sigmoid()
                    outputs_classes.append(outputs_class.flatten(0, 1))
                    outputs_coords.append(outputs_coord.flatten(0, 1))
                outputs_classes = torch.cat(outputs_classes)[None]
                outputs_coords = torch.cat(outputs_coords)[None]
                outputs_classes_list.append(outputs_classes)
                outputs_coords_list.append(outputs_coords)
            out.update({'pred_logits': outputs_classes_list,
                        'pred_boxes': outputs_coords_list})
        return out