# Code adapted from https://github.com/ylaxor/clip-like/blob/main/fine-tune-clip.ipynb
import os
from typing import Callable

import torch
from tqdm import tqdm
from transformers import CLIPProcessor, CLIPModel
from timm.scheduler import CosineLRScheduler


class ModelTrainer:
    """Fine-tunes a CLIP model with its contrastive loss and evaluates it on a validation split."""

    def __init__(self,
                 model: Callable,
                 processor: Callable,
                 data_name: str,
                 train_data_loader: torch.utils.data.DataLoader,
                 val_data_loader: torch.utils.data.DataLoader,
                 num_epochs: int,
                 learning_rate: float = 5e-5,
                 weight_decay: float = 1e-3,
                 device: str = "cuda:0",
                 save_model: bool = False,
                 save_model_path: str = "./fine_tuned_clip_models",
                 data_seed: int = 42,
                 method: str = "COCO_CF",
                 ) -> None:
        self.model = model
        self.processor = processor
        self.data_name = data_name
        self.train_data_loader = train_data_loader
        self.val_data_loader = val_data_loader
        self.num_epochs = num_epochs
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.device = device
        self.save_model = save_model
        self.save_model_path = save_model_path
        self.data_seed = data_seed
        self.method = method
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
        )

    def train(self):
        self.model.train()
        # Cosine schedule over all optimizer updates, with a short step-based warmup.
        lr_scheduler = CosineLRScheduler(
            self.optimizer,
            t_initial=self.num_epochs * len(self.train_data_loader),
            lr_min=2e-7,
            warmup_lr_init=1e-7,
            warmup_prefix=True,
            warmup_t=3,
            cycle_limit=1,
            t_in_epochs=False,
        )
        progress_bar = tqdm(range(self.num_epochs))
        for epoch in progress_bar:
            running_loss = 0.0
            for batch_idx, batch in enumerate(self.train_data_loader):
                self.optimizer.zero_grad()
                processed_input = self.processor(text=batch["caption"],
                                                 images=batch["image"],
                                                 return_tensors="pt",
                                                 padding=True,
                                                 max_length=128,
                                                 truncation=True)
                outputs = self.model(input_ids=processed_input['input_ids'].squeeze().to(self.device),
                                     pixel_values=processed_input['pixel_values'].squeeze().to(self.device),
                                     attention_mask=processed_input['attention_mask'].squeeze().to(self.device),
                                     return_loss=True)
                loss = outputs.loss
                loss.backward()
                running_loss += loss.item() * len(batch["caption"])
                self.optimizer.step()
                lr_scheduler.step_update(batch_idx + epoch * len(self.train_data_loader))
            print(f"Epoch {epoch+1}/{self.num_epochs} Loss: {running_loss/len(self.train_data_loader.dataset):.4f}")
            progress_bar.set_postfix(
                epoch="{}/{}".format(epoch + 1, self.num_epochs),
                loss=running_loss / len(self.train_data_loader.dataset),
                lr=self.optimizer.param_groups[0]["lr"],
            )
        if self.save_model:
            # Per-seed subsets include the data seed in the checkpoint name; 'MS_COCO' and 'all' do not.
            file_name = f'clip_model_dataset_{self.data_name}_method_{self.method}_num_epochs_{self.num_epochs}'
            if self.data_name not in ['MS_COCO', 'all']:
                file_name += f'_data_seed_{self.data_seed}'
            file_name += '.pt'
            os.makedirs(self.save_model_path, exist_ok=True)
            torch.save(self.model.state_dict(), os.path.join(self.save_model_path, file_name))
            print(f"Saving fine-tuned model as {file_name}")

    def eval(self):
        self.model.eval()
        nb_batches = len(self.val_data_loader)
        tqdm_object = tqdm(self.val_data_loader, total=nb_batches)
        epoch_loss = 0.0
        with torch.no_grad():
            for i, batch in enumerate(tqdm_object):
                processed_input = self.processor(text=batch["caption"],
                                                 images=batch["image"],
                                                 return_tensors="pt",
                                                 padding=True,
                                                 max_length=128,
                                                 truncation=True)
                outputs = self.model(input_ids=processed_input['input_ids'].squeeze().to(self.device),
                                     attention_mask=processed_input['attention_mask'].squeeze().to(self.device),
                                     pixel_values=processed_input['pixel_values'].squeeze().to(self.device),
                                     return_loss=True)
                loss, logits_per_image = outputs.loss, outputs.logits_per_image
                epoch_loss += loss.item()
                tqdm_object.set_postfix(
                    batch="{}/{}".format(i + 1, nb_batches),
                    dev_loss=loss.item(),
                )
        epoch_loss = epoch_loss / nb_batches
        print(f"Eval loss: {epoch_loss:.4f}")


def main():
    # os.environ['HF_HOME'] = ''  # optionally set a cache path for downloaded Hugging Face models
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_epochs', type=int, default=20)
    parser.add_argument('--data_name', type=str, default="MS_COCO", choices=["MS_COCO", "base", "medium", "all"])
    parser.add_argument('--learning_rate', type=float, default=1e-5)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--save_model', action='store_true', default=False)
    parser.add_argument('--method', type=str, choices=['COCO_CF', 'APGD_1', 'APGD_4', 'NONE'])
    parser.add_argument('--save_model_path', type=str, default="./fine_tuned_clip_models")
    parser.add_argument(
        "--data_seeds",
        nargs="+",
        type=int,
        default=[107],
        help="Seeds to use for each trial for picking demonstrations and eval sets",
    )
    args = parser.parse_args()
    if args.data_name == 'MS_COCO':
        assert args.method == 'NONE', "Only NONE method is allowed with MS_COCO dataset"

    from torch.utils.data import DataLoader
    from coco_cf_loader import MS_COCO_dataset, custom_collate_fn

    torch.manual_seed(42)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    for data_seed in args.data_seeds:
        if args.data_name not in ['MS_COCO', 'all']:
            print(f"Data Seed: {data_seed} | Data Name: {args.data_name} | Method: {args.method}")
            dataset = MS_COCO_dataset(base_dir=f'./clip_train_datasets/MS_COCO_{args.method}',
                                      annotation_file=f'/json_files/data_name_{args.data_name}_data_seed_{data_seed}.json')
        elif args.data_name == 'all':
            print(f"Data Name: {args.data_name} | Method: {args.method}")
            dataset = MS_COCO_dataset(base_dir=f'./clip_train_datasets/MS_COCO_{args.method}',
                                      annotation_file=f'/json_files/data_name_{args.data_name}.json')
        else:
            print(f"Data Name: {args.data_name} | Method: {args.method}")
            dataset = MS_COCO_dataset(base_dir='./clip_train_datasets/MS_COCO',
                                      annotation_file='/ms_coco_captions.json')

        train_size = int(0.8 * len(dataset))  # 80% for training
        val_size = len(dataset) - train_size  # 20% for validation
        # Randomly split into training and validation datasets
        train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
        # Create DataLoaders for each split
        train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=custom_collate_fn)
        val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False, collate_fn=custom_collate_fn, drop_last=True)

        trainer = ModelTrainer(model=model,
                               processor=processor,
                               data_name=args.data_name,
                               train_data_loader=train_loader,
                               val_data_loader=val_loader,
                               num_epochs=args.num_epochs,
                               learning_rate=args.learning_rate,
                               weight_decay=1e-3,
                               device=device,
                               data_seed=data_seed,
                               save_model=args.save_model,
                               save_model_path=args.save_model_path,
                               method=args.method,
                               )
        trainer.train()
        trainer.eval()

        # 'MS_COCO' and 'all' do not depend on the data seed, so a single run suffices.
        if args.data_name in ['MS_COCO', 'all']:
            break


if __name__ == "__main__":
    main()