import argparse
import io

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image

from models.blip2_decoder import BLIP2Decoder
from models.deformable_detr.backbone import build_backbone
from models.contextdet_blip2 import ContextDET
from models.post_process import CondNMSPostProcess
from models.transformer import build_ov_transformer
from util.misc import nested_tensor_from_tensor_list


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=str, default='cpu')
    parser.add_argument('--lr_backbone_names', default=["backbone.0"], type=str, nargs='+')
    parser.add_argument('--lr_backbone', default=2e-5, type=float)
    # Flags declared with default=True and action='store_false' are on by default;
    # passing the flag on the command line turns the corresponding feature off.
    parser.add_argument('--with_box_refine', default=True, action='store_false')
    parser.add_argument('--two_stage', default=True, action='store_false')

    # * Backbone
    parser.add_argument('--backbone', default='resnet50', type=str,
                        help="Name of the convolutional backbone to use")
    parser.add_argument('--dilation', action='store_true',
                        help="If true, we replace stride with dilation in the last convolutional block (DC5)")
    parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
                        help="Type of positional embedding to use on top of the image features")
    parser.add_argument('--position_embedding_scale', default=2 * np.pi, type=float,
                        help="position / size * scale")
    parser.add_argument('--num_feature_levels', default=5, type=int, help='number of feature levels')

    # * Transformer
    parser.add_argument('--enc_layers', default=6, type=int,
                        help="Number of encoding layers in the transformer")
    parser.add_argument('--dec_layers', default=6, type=int,
                        help="Number of decoding layers in the transformer")
    parser.add_argument('--dim_feedforward', default=2048, type=int,
                        help="Intermediate size of the feedforward layers in the transformer blocks")
    parser.add_argument('--hidden_dim', default=256, type=int,
                        help="Size of the embeddings (dimension of the transformer)")
    parser.add_argument('--dropout', default=0.0, type=float,
                        help="Dropout applied in the transformer")
    parser.add_argument('--nheads', default=8, type=int,
                        help="Number of attention heads inside the transformer's attentions")
    parser.add_argument('--num_queries', default=900, type=int,
                        help="Number of query slots")
    parser.add_argument('--dec_n_points', default=4, type=int)
    parser.add_argument('--enc_n_points', default=4, type=int)

    # * Segmentation
    parser.add_argument('--masks', action='store_true',
                        help="Train segmentation head if the flag is provided")
    parser.add_argument('--assign_first_stage', default=True, action='store_false')
    parser.add_argument('--assign_second_stage', default=True, action='store_false')

    parser.add_argument('--name', default='ov')
    parser.add_argument('--llm_name', default='bert-base-cased')
    parser.add_argument('--resume', default='', type=str)
    return parser.parse_args()


COLORS = [
    [0.000, 0.447, 0.741],
    [0.850, 0.325, 0.098],
    [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188],
    [0.301, 0.745, 0.933],
]


def fig2img(fig):
    """Convert a matplotlib figure into a PIL image via an in-memory PNG buffer."""
    buf = io.BytesIO()
    fig.savefig(buf)
    buf.seek(0)
    img = Image.open(buf)
    return img


def visualize_prediction(pil_img, output_dict, threshold=0.7):
    """Draw the predicted boxes whose score exceeds `threshold` on top of the image."""
    keep = output_dict["scores"] > threshold
    boxes = output_dict["boxes"][keep].tolist()
    scores = output_dict["scores"][keep].tolist()
    keep_list = keep.nonzero().squeeze(1).cpu().numpy().tolist()
    labels = [output_dict["names"][i] for i in keep_list]

    plt.figure(figsize=(12.8, 8))
    plt.imshow(pil_img)
    ax = plt.gca()
    colors = COLORS * 100
    for score, (xmin, ymin, xmax, ymax), label, color in zip(scores, boxes, labels, colors):
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                   fill=False, color=color, linewidth=3))
        ax.text(xmin, ymin, f"{label}: {score:0.2f}", fontsize=15,
                bbox=dict(facecolor="yellow", alpha=0.5))
    plt.axis("off")
    return fig2img(plt.gcf())


class ContextDetDemo:
    def __init__(self, resume):
        # ImageNet-style preprocessing; the shorter image side is resized to 640.
        self.transform = T.Compose([
            T.Resize(640),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

        args = parse_args()
        args.llm_name = 'caption_coco_opt2.7b'
        args.resume = resume
        args.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        num_classes = 2
        device = torch.device(args.device)

        # Build ContextDET: Deformable-DETR backbone and transformer plus a BLIP-2 language decoder.
        backbone = build_backbone(args)
        transformer = build_ov_transformer(args)
        llm_decoder = BLIP2Decoder(args.llm_name)
        model = ContextDET(
            backbone,
            transformer,
            num_classes=num_classes,
            num_queries=args.num_queries,
            num_feature_levels=args.num_feature_levels,
            aux_loss=False,
            with_box_refine=args.with_box_refine,
            two_stage=args.two_stage,
            llm_decoder=llm_decoder,
        )
        model = model.to(device)

        # Load the checkpoint non-strictly and report any mismatched keys.
        checkpoint = torch.load(args.resume, map_location='cpu')
        missing_keys, unexpected_keys = model.load_state_dict(checkpoint['model'], strict=False)
        if len(missing_keys) > 0:
            print('Missing Keys: {}'.format(missing_keys))
        if len(unexpected_keys) > 0:
            print('Unexpected Keys: {}'.format(unexpected_keys))

        postprocessor = CondNMSPostProcess(args.num_queries)

        self.model = model
        self.model.eval()
        self.postprocessor = postprocessor

    def forward(self, image, text, task_button, history, threshold=0.3):
        # Prepare the detector input as a padded NestedTensor batch of size 1.
        samples = self.transform(image).unsqueeze(0)
        samples = nested_tensor_from_tensor_list(samples)
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        samples = samples.to(device)

        # Build the BLIP-2 prompt according to the selected task.
        vis = self.model.llm_decoder.vis_processors
        if task_button == "Question Answering":
            text = f"{text} Answer:"
            history.append(text)
            # prompt = " ".join(history)
            prompt = text
        elif task_button == "Captioning":
            prompt = "A photo of"
        else:
            prompt = text
        blip2_samples = {
            'image': vis['eval'](image)[None, :].to(device),
            'prompt': [prompt],
        }

        # Run the model and rescale the predicted boxes to the original image size.
        outputs = self.model(samples, blip2_samples, mask_infos=None, task_button=task_button)
        mask_infos = outputs['mask_infos_pred']
        pred_names = [list(mask_info.values()) for mask_info in mask_infos]
        orig_target_sizes = torch.tensor([tuple(reversed(image.size))]).to(device)  # (height, width)
        results = self.postprocessor(outputs, orig_target_sizes, pred_names, mask_infos)[0]
        image_vis = visualize_prediction(image, results, threshold)

        # Update the chat transcript shown in the UI; only Question Answering keeps a running history.
        out_text = outputs['output_text'][0]
        if task_button == "Cloze Test":
            history = []
            chat = [
                (prompt, out_text),
            ]
        elif task_button == "Captioning":
            history = []
            chat = [
                ("please describe the image", out_text),
            ]
        elif task_button == "Question Answering":
            history += [out_text]
            chat = [
                (history[i], history[i + 1]) for i in range(0, len(history) - 1, 2)
            ]
        else:
            history = []
            chat = []
        return image_vis, chat, history
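

# ---------------------------------------------------------------------------
# Usage sketch (assumption): ContextDetDemo is normally wired into a Gradio UI;
# the block below only illustrates how it could be called directly. The
# checkpoint and image paths are hypothetical placeholders, not files shipped
# with this demo.
if __name__ == '__main__':
    demo = ContextDetDemo(resume='ckpt/contextdet_demo.pth')  # hypothetical checkpoint path
    img = Image.open('example.jpg').convert('RGB')            # hypothetical test image
    # "Captioning" ignores the free-form text and uses the fixed "A photo of" prompt.
    image_vis, chat, history = demo.forward(img, '', 'Captioning', history=[])
    image_vis.save('prediction.png')
    print(chat)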