# Code adapted from https://github.com/openai/CLIP/blob/main/
import argparse
import os
from datetime import datetime

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import CLIPProcessor, CLIPModel

from datasets_classes_templates import data_seeds
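
# Zero-shot image classification with (optionally fine-tuned) CLIP ViT-B/32: a
# prompt-ensembled text classifier is built for the chosen dataset, and top-1/top-5
# accuracy is measured on its test split. Example invocation (the script filename
# here is an assumption):
#   python zeroshot_image_classification.py --data base --dataset CIFAR10 --method COCO_CF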

def zeroshot_classifier(classnames, templates, processor, model):
    device = model.device  # use whatever device the model lives on instead of hardcoding 'cuda'
    with torch.no_grad():
        zeroshot_weights = []
        for classname in tqdm(classnames):
            texts = [template.format(classname) for template in templates]  # format each template with the class name
            text_inputs = processor(text=texts, return_tensors="pt", padding=True, truncation=True).to(device)
            # pass the attention mask along with input_ids so padding tokens are masked out
            class_embeddings = model.get_text_features(**text_inputs)  # embed with the text encoder
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)  # ensemble the prompt templates
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
    return zeroshot_weights
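
# zeroshot_classifier implements CLIP-style prompt ensembling: each class name is
# substituted into every template, the per-prompt text embeddings are L2-normalized
# and averaged, and the mean is normalized again, so every class ends up as a single
# unit vector. The template format shown here (e.g. "a photo of a {}.") is
# illustrative; the actual prompts are defined in datasets_classes_templates.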

def classification_collate_fn(batch):
    images, labels = zip(*batch)
    labels = torch.tensor(labels)
    return images, labels
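
# Note: classification_collate_fn intentionally leaves the images un-collated (a tuple
# of PIL images) instead of stacking tensors; CLIPProcessor performs the resizing,
# normalization, and tensor conversion per batch inside the evaluation loop below.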

def accuracy(output, target, topk=(1,)):
    pred = output.topk(max(topk), 1, True, True)[1].t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk]
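
# accuracy() returns raw counts of correct predictions, not percentages: for each k,
# a sample counts as correct if its target label is among the k highest logits.
# The counts are accumulated across batches below and divided by n at the end.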

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--data", type=str, default=None, choices=['non_fine_tuned', 'MS_COCO', 'medium', 'base', 'all'], help='Data on which CLIP was fine-tuned')
    parser.add_argument("--dataset", type=str, default="CIFAR10", choices=["CIFAR10", "CIFAR100", "ImageNet", "Caltech101", "Caltech256", "Food101"])
    parser.add_argument("--method", type=str, default="COCO_CF", choices=['COCO_CF', 'APGD_1', 'APGD_4', 'NONE'])
    args = parser.parse_args()

    current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    results_filename = f'./Results/fine_tuned_clip/zeroshot_image_classification_results_{args.dataset}_{args.data}_{args.method}_{current_time}.txt'
    os.makedirs(os.path.dirname(results_filename), exist_ok=True)  # create the results directory if it does not exist yet
    with open(results_filename, 'w') as f:
        f.write(f'Arguments: {args}\n\n')

    if args.data == 'MS_COCO':
        assert args.method == 'NONE', 'Use NONE for method for MS_COCO data'

    imagenet_path = '/software/ais2t/pytorch_datasets/imagenet/'  # Fill the path for ImageNet here
    if args.dataset == "CIFAR10":
        from datasets_classes_templates import CIFAR10_CLASSES_TEMPLATES as classes_templates
        from torchvision.datasets import CIFAR10
        data = CIFAR10(root='./image_classification_datasets/cifar10/', train=False, download=True)
    elif args.dataset == "CIFAR100":
        from datasets_classes_templates import CIFAR100_CLASSES_TEMPLATES as classes_templates
        from torchvision.datasets import CIFAR100
        data = CIFAR100(root='./image_classification_datasets/cifar100/', train=False, download=True)
    elif args.dataset == "ImageNet":
        from datasets_classes_templates import ImageNet_CLASSES_TEMPLATES as classes_templates
        from torchvision.datasets import ImageNet
        data = ImageNet(root=imagenet_path, split='val')
    elif args.dataset == "Caltech101":
        torch.manual_seed(42)  # fix the split so the held-out 20% is reproducible across runs
        from datasets_classes_templates import Caltech101_CLASSES_TEMPLATES as classes_templates
        from torchvision.datasets import Caltech101
        data = Caltech101(root='./image_classification_datasets/', download=False)
        train_size = int(0.8 * len(data))  # 80% for training
        val_size = len(data) - train_size
        _, data = torch.utils.data.random_split(data, [train_size, val_size])
    elif args.dataset == "Caltech256":
        torch.manual_seed(42)  # fix the split so the held-out 20% is reproducible across runs
        from datasets_classes_templates import Caltech256_CLASSES_TEMPLATES as classes_templates
        from torchvision.datasets import Caltech256
        data = Caltech256(root='./image_classification_datasets/', download=False)
        train_size = int(0.8 * len(data))  # 80% for training
        val_size = len(data) - train_size
        _, data = torch.utils.data.random_split(data, [train_size, val_size])
    elif args.dataset == "Food101":
        from datasets_classes_templates import Food101_CLASSES_TEMPLATES as classes_templates
        from torchvision.datasets import Food101
        data = Food101(root='./image_classification_datasets/food101/', download=True, split='test')
    else:
        raise NotImplementedError
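
    # The *_CLASSES_TEMPLATES objects imported above are assumed to be dicts holding
    # 'classes' (a list of class names) and 'templates' (a list of prompt strings with
    # a '{}' placeholder), matching how they are indexed when building the classifier.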

    print(f'Conducting zero-shot image classification on {args.dataset}')
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    model_base_path = './fine_tuned_clip_models'
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    top1_list = []
    for data_seed in data_seeds:
        print(f'Conducting zero-shot image classification on {args.data} with seed {data_seed} for the method {args.method}')
        model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
        # load the matching fine-tuned checkpoint, if any (map_location keeps CPU-only runs working)
        if args.data != 'non_fine_tuned':
            if args.method != 'NONE':
                if args.data != 'all':
                    model.load_state_dict(torch.load(f'{model_base_path}/{args.method}/clip_model_dataset_{args.data}_method_{args.method}_num_epochs_20_data_seed_{data_seed}.pt', map_location=device))
                else:
                    model.load_state_dict(torch.load(f'{model_base_path}/{args.method}/clip_model_dataset_{args.data}_method_{args.method}_num_epochs_20.pt', map_location=device))
            elif args.method == 'NONE' and args.data == 'MS_COCO':
                model.load_state_dict(torch.load(f'{model_base_path}/{args.method}/clip_model_dataset_{args.data}_method_{args.method}_num_epochs_20.pt', map_location=device))
        model.eval()
        data_loader = DataLoader(data, batch_size=128, collate_fn=classification_collate_fn, shuffle=False)
        zeroshot_weights = zeroshot_classifier(classes_templates['classes'],
                                               classes_templates['templates'],
                                               processor,
                                               model)
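
        # zeroshot_weights has shape [embed_dim, num_classes]: one ensembled, unit-norm
        # text embedding per class, so image_features @ zeroshot_weights below yields
        # cosine-similarity logits of shape [batch_size, num_classes].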
        with torch.no_grad():
            top1, top5, n = 0., 0., 0.
            for i, (images, target) in enumerate(tqdm(data_loader)):
                target = target.to(device)
                images = list(images)
                images = processor(images=images, return_tensors="pt").to(device)
                # predict
                image_features = model.get_image_features(images['pixel_values'])
                image_features /= image_features.norm(dim=-1, keepdim=True)
                logits = 100. * image_features @ zeroshot_weights  # 100. is the logit scale used by the original CLIP zero-shot code
                # measure accuracy
                acc1, acc5 = accuracy(logits, target, topk=(1, 5))
                top1 += acc1
                top5 += acc5
                n += image_features.size(0)

        top1 = (top1 / n) * 100
        top5 = (top5 / n) * 100
        with open(results_filename, 'a') as f:
            f.write(f'Seed {data_seed}: Top-1 Accuracy: {top1:.2f}, Top-5 Accuracy: {top5:.2f}\n')
        top1_list.append(top1)
        print(f"Top-1 accuracy: {top1:.2f}")
        print(f"Top-5 accuracy: {top5:.2f}")
        print('-' * 40)
        # these configurations use a single checkpoint, so there is nothing to average over seeds
        if args.method == 'NONE' or args.data in ['MS_COCO', 'all', 'non_fine_tuned']:
            break

    top1 = np.asarray(top1_list)
    print(f'Mean of the top-1 accuracy is {np.mean(top1):.2f}')
    print(f'Standard deviation of the top-1 accuracy is {np.std(top1):.2f}')
    with open(results_filename, 'a') as f:
        f.write(f'\nMean Top-1 Accuracy: {np.mean(top1):.2f}\n')
        f.write(f'Standard Deviation of Top-1 Accuracy: {np.std(top1):.2f}\n')


if __name__ == "__main__":
    main()