# Step 2: Import necessary libraries
import gradio as gr
from PIL import Image
from transformers import CLIPProcessor, CLIPModel, AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftConfig, PeftModel
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.cache_utils import DynamicCache, StaticCache

# Step 3: Set device and default dtype
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_default_dtype(torch.float32 if DEVICE.type == "cpu" else torch.float16)

# Step 4: Load CLIP model and processor
clip_model = CLIPModel.from_pretrained(
    "openai/clip-vit-base-patch32",
    torch_dtype=torch.float32 if DEVICE.type == "cpu" else torch.float16
).to(DEVICE)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True)

# Step 5: Define the MultiModalModel class
class MultiModalModel(nn.Module):
    def __init__(self, phi_model_name="microsoft/phi-3-mini-4k-instruct",
                 clip_model_name="openai/clip-vit-base-patch32"):
        super().__init__()
        self.phi = None  # Will be set after loading the PEFT model
        self.tokenizer = AutoTokenizer.from_pretrained(phi_model_name, trust_remote_code=True)
        self.tokenizer.add_special_tokens({"additional_special_tokens": ["[IMG]"], "pad_token": "<pad>"})
        self.clip = CLIPModel.from_pretrained(
            clip_model_name,
            torch_dtype=torch.float32 if DEVICE.type == "cpu" else torch.float16
        ).eval().to(DEVICE)
        image_embedding_dim = self.clip.config.projection_dim
        phi_hidden_size = 3072  # Hardcoded for Phi-3 mini

        # Project CLIP image embeddings into the Phi-3 hidden space
        self.image_projection = nn.Sequential(
            nn.Linear(image_embedding_dim, phi_hidden_size,
                      dtype=torch.float32 if DEVICE.type == "cpu" else torch.float16),
            nn.LayerNorm(phi_hidden_size,
                         dtype=torch.float32 if DEVICE.type == "cpu" else torch.float16),
            nn.Dropout(0.1)
        ).to(DEVICE)
        nn.init.xavier_uniform_(self.image_projection[0].weight, gain=1.0)
        nn.init.zeros_(self.image_projection[0].bias)

    def forward(self, text_input_ids, attention_mask=None, image_embedding=None):
        # Sanitize and normalize the incoming CLIP image embedding
        image_embedding = torch.clamp(image_embedding, min=-1e4, max=1e4)
        image_embedding = F.normalize(image_embedding, dim=-1, eps=1e-5).to(
            torch.float32 if DEVICE.type == "cpu" else torch.float16
        )

        # Keep the projection weights in a bounded range for fp16 stability
        with torch.no_grad():
            self.image_projection[0].weight.clamp_(-1.0, 1.0)
            self.image_projection[0].bias.clamp_(-1.0, 1.0)

        projected_image = self.image_projection(image_embedding)
        projected_image = torch.clamp(projected_image, min=-1e4, max=1e4)
        if torch.isnan(projected_image).any() or torch.isinf(projected_image).any():
            print("Warning: Projected image contains NaN or Inf values after clamping, replacing with zeros")
            projected_image = torch.where(
                torch.logical_or(torch.isnan(projected_image), torch.isinf(projected_image)),
                torch.zeros_like(projected_image),
                projected_image
            )
        if projected_image.dim() == 2:
            projected_image = projected_image.unsqueeze(1)

        # Replace the first [IMG] token embedding in each sequence with the projected image embedding
        text_embeddings = self.phi.get_input_embeddings()(text_input_ids)
        fused_embeddings = text_embeddings.clone()
        img_token_id = self.tokenizer.convert_tokens_to_ids("[IMG]")
        img_token_mask = (text_input_ids == img_token_id)
        for i in range(fused_embeddings.shape[0]):
            img_positions = img_token_mask[i].nonzero(as_tuple=True)[0]
            if img_positions.numel() > 0:
                fused_embeddings[i, img_positions[0], :] = projected_image[i, 0, :]

        if torch.isnan(fused_embeddings).any() or torch.isinf(fused_embeddings).any():
            print("Warning: Fused embeddings contain NaN or Inf values, replacing with zeros")
            fused_embeddings = torch.where(
                torch.logical_or(torch.isnan(fused_embeddings), torch.isinf(fused_embeddings)),
                torch.zeros_like(fused_embeddings),
                fused_embeddings
            )
        return fused_embeddings
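
# --- Optional smoke test: a minimal usage sketch, not part of the original steps ---
# This assumes self.phi is attached to the base Phi-3 model before calling forward();
# the original code leaves self.phi = None until the PEFT adapter is loaded in a later
# step. The image path "example.jpg" and the prompt text are illustrative placeholders.
if __name__ == "__main__":
    model = MultiModalModel().to(DEVICE)

    # Assumption: attach the base language model directly; the later PEFT step would
    # normally wrap this with PeftModel.from_pretrained(...) instead.
    model.phi = AutoModelForCausalLM.from_pretrained(
        "microsoft/phi-3-mini-4k-instruct",
        torch_dtype=torch.float32 if DEVICE.type == "cpu" else torch.float16,
        trust_remote_code=True,
    ).to(DEVICE)
    model.phi.resize_token_embeddings(len(model.tokenizer))  # account for the added [IMG] / pad tokens

    # Encode an image with CLIP to get a (1, projection_dim) embedding
    image = Image.open("example.jpg")  # placeholder path
    clip_inputs = clip_processor(images=image, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        image_embedding = clip_model.get_image_features(**clip_inputs)

    # Tokenize a prompt that contains the [IMG] placeholder token
    prompt = "[IMG] Describe this image."
    encoded = model.tokenizer(prompt, return_tensors="pt").to(DEVICE)

    fused = model(
        text_input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        image_embedding=image_embedding,
    )
    print(fused.shape)  # expected: (1, sequence_length, 3072)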