# Code adapted from https://github.com/ylaxor/clip-like/blob/main/fine-tune-clip.ipynb
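"""Fine-tune CLIP (openai/clip-vit-base-patch32) on MS-COCO-style image-caption data.

The script wraps training and evaluation in a `ModelTrainer` class: it optimizes the
model's built-in contrastive loss with AdamW and a cosine learning-rate schedule, and
optionally saves a checkpoint per dataset/method/seed combination.
"""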
import os
from typing import Callable

import torch
from timm.scheduler import CosineLRScheduler
from tqdm import tqdm
from transformers import CLIPModel, CLIPProcessor


class ModelTrainer:
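    """Fine-tunes a CLIP model on image-caption pairs and evaluates the contrastive loss on a validation split."""
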
    def __init__(self,
                 model: Callable,
                 processor: Callable,
                 data_name: str,
                 train_data_loader: torch.utils.data.DataLoader,
                 val_data_loader: torch.utils.data.DataLoader,
                 num_epochs: int,
                 learning_rate: float = 5e-5,
                 weight_decay: float = 1e-3,
                 device: str = "cuda:0",
                 save_model: bool = False,
                 save_model_path: str = "./fine_tuned_clip_models",
                 data_seed: int = 42,
                 method: str = "COCO_CF",
                 ) -> None:
        self.model = model
        self.processor = processor
        self.data_name = data_name
        self.train_data_loader = train_data_loader
        self.val_data_loader = val_data_loader
        self.num_epochs = num_epochs
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.device = device
        self.save_model = save_model
        self.save_model_path = save_model_path
        self.data_seed = data_seed
        self.method = method
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay
        )
    def train(self):
        """Run contrastive fine-tuning for `num_epochs` epochs over the training loader."""
        self.model.train()
        lr_scheduler = CosineLRScheduler(
            self.optimizer,
            t_initial=self.num_epochs * len(self.train_data_loader),
            lr_min=2e-7,
            warmup_lr_init=1e-7,
            warmup_prefix=True,
            warmup_t=3,
            cycle_limit=1,
            t_in_epochs=False,
        )
        progress_bar = tqdm(range(self.num_epochs))
        for epoch in progress_bar:
            running_loss = 0.0
            for batch_idx, batch in enumerate(self.train_data_loader):
                self.optimizer.zero_grad()
                processed_input = self.processor(text=batch["caption"],
                                                 images=batch["image"],
                                                 return_tensors="pt",
                                                 padding=True,
                                                 max_length=128,
                                                 truncation=True
                                                 )
                # CLIPModel returns the symmetric image-text contrastive loss when return_loss=True.
                outputs = self.model(input_ids=processed_input['input_ids'].squeeze().to(self.device),
                                     pixel_values=processed_input['pixel_values'].squeeze().to(self.device),
                                     attention_mask=processed_input['attention_mask'].squeeze().to(self.device),
                                     return_loss=True
                                     )
                loss = outputs.loss
                loss.backward()
                running_loss += loss.item() * len(batch["caption"])
                self.optimizer.step()
                # Per-step schedule update (t_in_epochs=False expects a global step index).
                lr_scheduler.step_update(batch_idx + epoch * len(self.train_data_loader))
            print(f"Epoch {epoch+1}/{self.num_epochs} Loss: {running_loss/len(self.train_data_loader.dataset):.4f}")
            progress_bar.set_postfix(
                epoch="{}/{}".format(epoch+1, self.num_epochs),
                loss=running_loss/len(self.train_data_loader.dataset),
                lr=self.optimizer.param_groups[0]["lr"]
            )
        if self.save_model:
            os.makedirs(self.save_model_path, exist_ok=True)
            if self.data_name not in ['MS_COCO', 'all']:
                checkpoint_name = f'clip_model_dataset_{self.data_name}_method_{self.method}_num_epochs_{self.num_epochs}_data_seed_{self.data_seed}.pt'
            else:
                checkpoint_name = f'clip_model_dataset_{self.data_name}_method_{self.method}_num_epochs_{self.num_epochs}.pt'
            torch.save(self.model.state_dict(), os.path.join(self.save_model_path, checkpoint_name))
            print(f"Saving fine-tuned model as {checkpoint_name}")
    def eval(self):
        """Compute the average contrastive loss over the validation loader."""
        self.model.eval()
        nb_batches = len(self.val_data_loader)
        tqdm_object = tqdm(self.val_data_loader, total=len(self.val_data_loader))
        epoch_loss = 0.0
        with torch.no_grad():
            for i, batch in enumerate(tqdm_object):
                processed_input = self.processor(text=batch["caption"],
                                                 images=batch["image"],
                                                 return_tensors="pt",
                                                 padding=True,
                                                 max_length=128,
                                                 truncation=True
                                                 )
                outputs = self.model(
                    input_ids=processed_input['input_ids'].squeeze().to(self.device),
                    attention_mask=processed_input['attention_mask'].squeeze().to(self.device),
                    pixel_values=processed_input['pixel_values'].squeeze().to(self.device),
                    return_loss=True)
                loss, logits_per_image = outputs.loss, outputs.logits_per_image
                epoch_loss += loss.item()
                tqdm_object.set_postfix(
                    batch="{}/{}".format(i+1, nb_batches),
                    dev_loss=loss.item(),
                )
        epoch_loss = epoch_loss / nb_batches
        print(f"Eval loss: {epoch_loss}")
def main():
    import argparse

    # os.environ['HF_HOME'] = ''  # Optionally set the cache directory for downloaded Hugging Face models.
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_epochs', type=int, default=20)
    parser.add_argument('--data_name', type=str, default="MS_COCO", choices=["MS_COCO", "base", "medium", "all"])
    parser.add_argument('--learning_rate', type=float, default=1e-5)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--save_model', action='store_true', default=False)
    parser.add_argument('--method', type=str, choices=['COCO_CF', 'APGD_1', 'APGD_4', 'NONE'])
    parser.add_argument('--save_model_path', type=str, default="./fine_tuned_clip_models")
    parser.add_argument(
        "--data_seeds",
        nargs="+",
        type=int,
        default=[107],
        help="Seeds to use for each trial for picking demonstrations and eval sets",
    )
    args = parser.parse_args()
    if args.data_name == 'MS_COCO':
        assert args.method == 'NONE', "Only the NONE method is allowed with the MS_COCO dataset"
    from torch.utils.data import DataLoader
    from coco_cf_loader import MS_COCO_dataset, custom_collate_fn

    torch.manual_seed(42)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    for data_seed in args.data_seeds:
        # Each dataset variant is read from a JSON annotation file; batches are dicts with "image" and "caption" keys.
        if args.data_name not in ['MS_COCO', 'all']:
            print(f"Data Seed: {data_seed} | Data Name: {args.data_name} | Method: {args.method}")
            dataset = MS_COCO_dataset(base_dir=f'./clip_train_datasets/MS_COCO_{args.method}',
                                      annotation_file=f'/json_files/data_name_{args.data_name}_data_seed_{data_seed}.json')
        elif args.data_name == 'all':
            print(f"Data Name: {args.data_name} | Method: {args.method}")
            dataset = MS_COCO_dataset(base_dir=f'./clip_train_datasets/MS_COCO_{args.method}',
                                      annotation_file=f'/json_files/data_name_{args.data_name}.json')
        else:
            print(f"Data Name: {args.data_name} | Method: {args.method}")
            dataset = MS_COCO_dataset(base_dir='./clip_train_datasets/MS_COCO',
                                      annotation_file='/ms_coco_captions.json')
        train_size = int(0.8 * len(dataset))  # 80% for training
        val_size = len(dataset) - train_size  # 20% for validation
        # Randomly split into training and validation datasets
        train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
        # Create DataLoaders for each subset
        train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=custom_collate_fn)
        val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False, collate_fn=custom_collate_fn, drop_last=True)
        trainer = ModelTrainer(model=model,
                               processor=processor,
                               data_name=args.data_name,
                               train_data_loader=train_loader,
                               val_data_loader=val_loader,
                               num_epochs=args.num_epochs,
                               learning_rate=args.learning_rate,
                               weight_decay=1e-3,
                               device=device,
                               data_seed=data_seed,
                               save_model=args.save_model,
                               save_model_path=args.save_model_path,
                               method=args.method,
                               )
        trainer.train()
        trainer.eval()
        # MS_COCO and 'all' do not depend on the data seed, so a single run is enough.
        if args.data_name in ['MS_COCO', 'all']:
            break


if __name__ == "__main__":
    main()
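
# Example invocations (a sketch, assuming the ./clip_train_datasets/ directories and their
# JSON annotation files referenced above are available locally):
#   python clip_train.py --data_name MS_COCO --method NONE --num_epochs 20 --save_model
#   python clip_train.py --data_name base --method COCO_CF --data_seeds 107 108 109 --save_model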