# Step 2: Import necessary libraries
import gradio as gr
from PIL import Image
from transformers import CLIPProcessor, CLIPModel, AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftConfig, PeftModel
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.cache_utils import DynamicCache, StaticCache
# Step 3: Set device and default dtype
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_default_dtype(torch.float16)
# Step 4: Load CLIP model and processor
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", torch_dtype=torch.float16).to(DEVICE)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True)
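# These module-level objects turn a PIL image into a 512-dim embedding via
# clip_model.get_image_features(); MultiModalModel.forward() below expects that
# embedding as its image_embedding argument.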
# Step 5: Define the MultiModalModel class
class MultiModalModel(nn.Module):
def __init__(self, phi_model_name="microsoft/phi-3-mini-4k-instruct", clip_model_name="openai/clip-vit-base-patch32"):
super().__init__()
self.phi = None # Will be set after loading the PEFT model
self.tokenizer = AutoTokenizer.from_pretrained(phi_model_name, trust_remote_code=True)
self.tokenizer.add_special_tokens({"additional_special_tokens": ["[IMG]"], "pad_token": "<pad>"})
self.clip = CLIPModel.from_pretrained(clip_model_name, torch_dtype=torch.float16).eval().to(DEVICE)
image_embedding_dim = self.clip.config.projection_dim
phi_hidden_size = 3072 # Hardcoded for Phi-3 mini
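        # Project the 512-dim CLIP image embedding into Phi-3's 3072-dim token-embedding
        # space so it can stand in for the [IMG] token embedding.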
self.image_projection = nn.Sequential(
nn.Linear(image_embedding_dim, phi_hidden_size, dtype=torch.float16),
nn.LayerNorm(phi_hidden_size, dtype=torch.float16),
nn.Dropout(0.1)
).to(DEVICE)
nn.init.xavier_uniform_(self.image_projection[0].weight, gain=1.0)
nn.init.zeros_(self.image_projection[0].bias)
    def forward(self, text_input_ids, attention_mask=None, image_embedding=None):
        if image_embedding is None:
            raise ValueError("image_embedding is required: pass the CLIP image features for the batch")
        image_embedding = torch.clamp(image_embedding, min=-1e4, max=1e4)
image_embedding = F.normalize(image_embedding, dim=-1, eps=1e-5).to(torch.float16)
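        # Numerical-stability guard for fp16: clamp the projection weights in place;
        # no_grad keeps the clamping out of the autograd graph.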
with torch.no_grad():
self.image_projection[0].weight.clamp_(-1.0, 1.0)
self.image_projection[0].bias.clamp_(-1.0, 1.0)
        projected_image = self.image_projection(image_embedding)
projected_image = torch.clamp(projected_image, min=-1e4, max=1e4)
if torch.isnan(projected_image).any() or torch.isinf(projected_image).any():
print("Warning: Projected image contains NaN or Inf values after clamping, replacing with zeros")
projected_image = torch.where(
torch.logical_or(torch.isnan(projected_image), torch.isinf(projected_image)),
torch.zeros_like(projected_image),
projected_image
)
if projected_image.dim() == 2:
projected_image = projected_image.unsqueeze(1)
text_embeddings = self.phi.get_input_embeddings()(text_input_ids)
fused_embeddings = text_embeddings.clone()
img_token_id = self.tokenizer.convert_tokens_to_ids("[IMG]")
img_token_mask = (text_input_ids == img_token_id)
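        # Replace the embedding at the first [IMG] position in each sequence with the
        # projected image vector; sequences without an [IMG] token are left unchanged.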
for i in range(fused_embeddings.shape[0]):
img_positions = img_token_mask[i].nonzero(as_tuple=True)[0]
if img_positions.numel() > 0:
fused_embeddings[i, img_positions[0], :] = projected_image[i, 0, :]
if torch.isnan(fused_embeddings).any() or torch.isinf(fused_embeddings).any():
print("Warning: Fused embeddings contain NaN or Inf values, replacing with zeros")
fused_embeddings = torch.where(
torch.logical_or(torch.isnan(fused_embeddings), torch.isinf(fused_embeddings)),
torch.zeros_like(fused_embeddings),
fused_embeddings
)
return fused_embeddings
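# --- Usage sketch (illustrative only; not part of the original module) ---
# One way to exercise the model end to end: encode an image with the CLIP model loaded
# above, then fuse it into a prompt via MultiModalModel.forward(). Loading a base Phi-3
# model here stands in for attaching the PEFT checkpoint to `model.phi`; the placeholder
# image and prompt are assumptions for illustration.
if __name__ == "__main__":
    model = MultiModalModel()
    model.phi = AutoModelForCausalLM.from_pretrained(
        "microsoft/phi-3-mini-4k-instruct", torch_dtype=torch.float16, trust_remote_code=True
    ).to(DEVICE)
    # Grow the embedding table only if the added [IMG]/<pad> tokens fall outside it.
    if len(model.tokenizer) > model.phi.get_input_embeddings().num_embeddings:
        model.phi.resize_token_embeddings(len(model.tokenizer))
    image = Image.new("RGB", (224, 224), color="white")  # placeholder image
    pixel_values = clip_processor(images=image, return_tensors="pt")["pixel_values"]
    pixel_values = pixel_values.to(DEVICE, dtype=torch.float16)
    with torch.no_grad():
        image_embedding = clip_model.get_image_features(pixel_values=pixel_values)  # (1, 512)
        encoded = model.tokenizer("[IMG] Describe the image.", return_tensors="pt").to(DEVICE)
        fused = model(
            text_input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            image_embedding=image_embedding,
        )
    print(fused.shape)  # (1, sequence_length, 3072)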