Spaces:
Build error
Build error
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------
| import timm | |
| import torch | |
| import copy | |
| import torch.nn as nn | |
| import torchvision | |
| import json | |
| from timm.models.hub import download_cached_file | |
| from PIL import Image | |
class MyViT(nn.Module):
    """ViT-B/16 backbone built with ``timm.create_model('vit_base_patch16_224')``.

    ``forward`` passes its input straight through the wrapped timm model and
    returns that model's output. In the demo in this file the output is fed to
    ``softmax`` and indexed by class id, i.e. it is treated as class scores
    with ``num_classes`` entries (TODO confirm for other uses).

    Args:
        num_classes: Size of the classification head requested from timm.
        pretrain_path: Checkpoint file handed to timm as ``checkpoint_path``.
            When ``None``, no checkpoint is passed and ``pretrained`` is not
            requested, so weights presumably come from timm's default
            initialisation.
        enable_fc: Accepted for backward compatibility; currently ignored.
    """

    def __init__(self, num_classes=115217, pretrain_path=None, enable_fc=False):
        super().__init__()
        print('initializing ViT model as backbone using ckpt:', pretrain_path)
        self.model = timm.create_model(
            'vit_base_patch16_224',
            checkpoint_path=pretrain_path,
            num_classes=num_classes,
        )

    def forward(self, x):
        """Run ``x`` through the wrapped timm model and return its output."""
        # Invoke the module itself rather than ``.forward`` so that registered
        # forward hooks and the rest of nn.Module's call machinery still run.
        return self.model(x)
def timmvit(**kwargs):
    """Factory returning a ``MyViT`` built from the given keyword arguments.

    No defaults are injected: every keyword is forwarded unchanged to
    ``MyViT.__init__``.
    """
    # Begin with an empty option set and layer the caller's kwargs on top,
    # which leaves a single place to add project-wide defaults later.
    options = dict()
    options.update(kwargs)
    return MyViT(**options)
def build_transforms(input_size, center_crop=True):
    """Build the evaluation preprocessing pipeline for a PIL image.

    With ``center_crop=True`` (the default, and the original behaviour):
    resize with ``Resize(input_size * 8 // 7)`` (256 for 224), then take a
    centred ``input_size`` x ``input_size`` crop.

    With ``center_crop=False``: resize directly to
    ``input_size`` x ``input_size`` without cropping.

    Both paths then convert to a tensor and normalise with the commonly used
    ImageNet channel mean/std.

    Args:
        input_size: Side length in pixels of the square model input.
        center_crop: Resize-then-center-crop (True) or resize straight to
            the target size (False). Previously this flag was ignored.

    Returns:
        A ``torchvision.transforms.Compose`` pipeline.
    """
    if center_crop:
        resize_steps = [
            # Oversize by 8/7 so the centre crop keeps 7/8 of the resized side.
            torchvision.transforms.Resize(input_size * 8 // 7),
            torchvision.transforms.CenterCrop(input_size),
        ]
    else:
        resize_steps = [torchvision.transforms.Resize((input_size, input_size))]
    return torchvision.transforms.Compose(resize_steps + [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225]),
    ])
def pil_loader(filepath):
    """Open the image at ``filepath`` and return it converted to RGB mode.

    ``convert('RGB')`` is called while the context manager still holds the
    file open; the converted image is returned after the file is closed.
    """
    with Image.open(filepath) as handle:
        rgb = handle.convert('RGB')
    return rgb
_DEMO_DIR = '/mnt/lustre/yhzhang/bamboo/Bamboo_ViT-B16_demo'


def test_build(id2name_path=_DEMO_DIR + '/trainid2name.json',
               image_path=_DEMO_DIR + '/142520422_6ad756ddf6_w_d.jpg',
               ckpt_path=_DEMO_DIR + '/Bamboo_v0-1_ViT-B16.pth.tar.convert',
               topk=5):
    """Run a single-image inference demo and print the top-k predictions.

    Steps:
    - Load the class-id -> name mapping from ``id2name_path``.
    - Preprocess ``image_path`` with ``build_transforms(224)``.
    - Build ``MyViT`` from ``ckpt_path``.
    - Print a ``{class_name: probability}`` dict for the ``topk`` classes
      with the highest softmax scores.

    The default arguments reproduce the original hard-coded demo paths.

    Returns:
        The printed ``{class_name: probability}`` dict.
    """
    with open(id2name_path, encoding='utf-8') as f:
        # Keys are stringified class ids (looked up with str(i) below). The
        # values are assumed to be sequences whose first item is the display
        # name — TODO confirm against the JSON file.
        id2name = json.load(f)

    img = pil_loader(image_path)
    img_t = build_transforms(224)(img)[None, :]  # prepend a batch dimension

    model = MyViT(pretrain_path=ckpt_path)
    # Inference only: switch off train-time behaviour such as dropout and do
    # not record an autograd graph.
    model.eval()
    with torch.no_grad():
        output = model(img_t)

    prediction = output.softmax(-1).flatten()
    _, top_idx = torch.topk(prediction, topk)
    result = {id2name[str(i)][0]: float(prediction[i]) for i in top_idx.tolist()}
    print(result)
    return result
if __name__ == '__main__':
    # Script entry point: run the single-image inference demo.
    test_build()