Spaces:
Build error
Build error
| import argparse | |
| import requests | |
| import gradio as gr | |
| import numpy as np | |
| import cv2 | |
| import torch | |
| import torch.nn as nn | |
| from PIL import Image | |
| import torchvision | |
| from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD | |
| from timm.data import create_transform | |
| from timmvit import timmvit | |
| import json | |
| from timm.models.hub import download_cached_file | |
| from PIL import Image | |
def pil_loader(filepath):
    """Open the image at ``filepath`` and return it converted to RGB mode.

    The file is opened inside a ``with`` block, so the object returned by
    ``Image.open`` is closed before the converted image is handed back.

    :param filepath: Location of the image, passed straight to ``Image.open``.
    :returns: The result of ``convert('RGB')`` on the opened image.
    """
    with Image.open(filepath) as source:
        rgb_image = source.convert('RGB')
    return rgb_image
def build_transforms(input_size, center_crop=True):
    """Compose the preprocessing pipeline applied to an image before inference.

    The steps run in this order: convert the input to a PIL image, resize
    with ``Resize(input_size * 8 // 7)``, crop with ``CenterCrop(input_size)``,
    convert to a tensor, and normalize each channel with fixed mean/std
    values.

    :param input_size: Size passed to ``CenterCrop``; the preceding resize
        uses ``input_size * 8 // 7``.
    :param center_crop: Currently not read by the body, so the center crop is
        applied regardless of this value.
    :returns: The composed ``torchvision`` transform.
    """
    tv = torchvision.transforms
    steps = [
        tv.ToPILImage(),
        tv.Resize(input_size * 8 // 7),
        tv.CenterCrop(input_size),
        tv.ToTensor(),
        tv.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    return tv.Compose(steps)
# Load the human-readable labels for Bamboo from the local JSON file.
# The encoding is given explicitly so the label names decode the same way
# regardless of the platform's default locale encoding.
with open('./trainid2name.json', encoding='utf-8') as f:
    id2name = json.load(f)
# Build the model: construct it via `timmvit` from the local converted
# checkpoint file, then call `eval()` on it before it is used for prediction.
model = timmvit(pretrain_path='./Bamboo_v0-1_ViT-B16.pth.tar.convert')
model.eval()
# Borrowed from pytorch-grad-cam:
# https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/utils/image.py
def show_cam_on_image(img: np.ndarray,
                      mask: np.ndarray,
                      use_rgb: bool = False,
                      colormap: int = cv2.COLORMAP_JET) -> np.ndarray:
    """Overlay a CAM mask on an image as a heatmap.

    The heatmap is in BGR channel order unless ``use_rgb`` is set.

    :param img: The base image in RGB or BGR format; its maximum value must
        not exceed 1.
    :param mask: The cam mask. It is scaled by 255 and cast to ``uint8``
        before the colormap is applied.
    :param use_rgb: Whether to convert the heatmap from BGR to RGB; set this
        to True if ``img`` is in RGB format.
    :param colormap: The OpenCV colormap to be used.
    :returns: The blend ``0.7 * heatmap + 0.3 * img``, scaled by 255 and cast
        to ``uint8``.
    :raises Exception: If the maximum of ``img`` is greater than 1.
    """
    colored = cv2.applyColorMap(np.uint8(255 * mask), colormap)
    if use_rgb:
        colored = cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)
    # Bring the heatmap to the same [0, 1] float scale expected of `img`.
    colored = np.float32(colored) / 255

    if np.max(img) > 1:
        raise Exception(
            "The input image should np.float32 in the range [0, 1]")

    blended = 0.7 * colored + 0.3 * img
    return np.uint8(255 * blended)
def recognize_image(image):
    """Classify an image and return its highest-scoring labels.

    The image is preprocessed with the module-level ``eval_transforms``,
    given a leading batch dimension, and passed through the module-level
    ``model``. The output is turned into probabilities with a softmax over
    the last dimension.

    :param image: Input accepted by ``eval_transforms`` (presumably the array
        supplied by the Gradio image input; confirm against the interface).
    :returns: A dict mapping a label name (the first entry of
        ``id2name[str(class_index)]``) to its probability as a Python float,
        for up to five top-scoring classes.
    """
    img_t = eval_transforms(image)
    # Inference only: disable autograd so no graph is built or kept alive.
    with torch.no_grad():
        output = model(img_t.unsqueeze(0))
        prediction = output.softmax(-1).flatten()
    # Never ask topk for more entries than there are classes.
    k = min(5, prediction.numel())
    top_probs, top_idx = torch.topk(prediction, k)
    return {
        id2name[str(idx)][0]: prob
        for prob, idx in zip(top_probs.tolist(), top_idx.tolist())
    }
# Preprocessing pipeline used by recognize_image, built for input size 224.
eval_transforms = build_transforms(224)

# Gradio components. `image` is constructed here, but the interface below is
# given the "image" string instead; `label` displays the top five classes.
image = gr.inputs.Image()
label = gr.outputs.Label(num_top_classes=5)

# Text shown on the demo page.
_description = "Bamboo for Image Recognition Demo (https://github.com/Davidzhangyuanhan/Bamboo). Bamboo knows what this object is and what you are doing in a very fine-grain granularity: fratercula arctica (fig.5) and dribbler (fig.2))."

# Example inputs offered in the UI; one single-element row per example image.
_example_rows = [
    ["./examples/playing_mahjong.jpg"],
    ["./examples/dribbler.jpg"],
    ["./examples/Ferrari-F355.jpg"],
    ["./examples/northern_oriole.jpg"],
    ["./examples/fratercula_arctica.jpg"],
    ["./examples/husky.jpg"],
    ["./examples/taraxacum_erythrospermum.jpg"],
]

# Wire recognize_image into the interface and start serving it.
gr.Interface(
    fn=recognize_image,
    inputs=["image"],
    outputs=[label],
    examples=_example_rows,
    description=_description,
).launch()