import os
import torch
import gradio as gr
from PIL import Image
from transformers import CLIPProcessor, CLIPModel, AutoTokenizer, AutoModelForCausalLM
import torch.nn as nn
import torch.nn.functional as F

# Force CPU usage
DEVICE = torch.device("cpu")

# Stick to float32: half-precision kernels are poorly supported on CPU
torch.set_default_dtype(torch.float32)

class MultiModalModel(nn.Module):
    def __init__(self, phi_model_name="microsoft/phi-3-mini-4k-instruct", clip_model_name="openai/clip-vit-base-patch32"):
        super().__init__()
        # Load LLM without quantization for CPU compatibility
        self.phi = AutoModelForCausalLM.from_pretrained(
            phi_model_name,
            return_dict=True,
            device_map="cpu",
            low_cpu_mem_usage=True,
            trust_remote_code=False
        )
        self.tokenizer = AutoTokenizer.from_pretrained(phi_model_name, trust_remote_code=False)
        self.tokenizer.add_special_tokens({"additional_special_tokens": ["[IMG]"], "pad_token": "<pad>"})
        self.phi.resize_token_embeddings(len(self.tokenizer))

        # Load CLIP model
        self.clip = CLIPModel.from_pretrained(clip_model_name).to(DEVICE)
        self.clip_processor = CLIPProcessor.from_pretrained(clip_model_name, use_fast=True)

        # Image projection layer
        image_embedding_dim = self.clip.config.projection_dim
        phi_hidden_size = self.phi.config.hidden_size
        self.image_projection = nn.Sequential(
            nn.Linear(image_embedding_dim, image_embedding_dim * 2),
            nn.GELU(),
            nn.Linear(image_embedding_dim * 2, phi_hidden_size),
            nn.LayerNorm(phi_hidden_size),
            nn.Dropout(0.1)
        )

        # Initialize weights
        nn.init.xavier_uniform_(self.image_projection[0].weight, gain=1.0)
        nn.init.zeros_(self.image_projection[0].bias)
        nn.init.xavier_uniform_(self.image_projection[2].weight, gain=1.0)
        nn.init.zeros_(self.image_projection[2].bias)
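        # With the default checkpoints, image_embedding_dim should be 512 (CLIP
        # ViT-B/32 projection_dim) and phi_hidden_size 3072 (Phi-3-mini), so the
        # projection maps 512 -> 1024 -> 3072.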

    def forward(self, text_input_ids, attention_mask=None, image_embedding=None, labels=None):
        # Project the normalized CLIP embedding into the Phi embedding space
        image_embedding = F.normalize(image_embedding, dim=-1)
        projected_image = 10.0 * self.image_projection(image_embedding)  # Amplify image signal
        if projected_image.dim() == 2:
            projected_image = projected_image.unsqueeze(1)

        # Replace the first [IMG] token embedding in each sequence with the projected image embedding
        text_embeddings = self.phi.get_input_embeddings()(text_input_ids)
        img_token_id = self.tokenizer.convert_tokens_to_ids("[IMG]")
        img_token_mask = (text_input_ids == img_token_id)
        fused_embeddings = text_embeddings.clone()
        for i in range(fused_embeddings.shape[0]):
            img_positions = img_token_mask[i].nonzero(as_tuple=True)[0]
            if img_positions.numel() > 0:
                fused_embeddings[i, img_positions[0], :] = projected_image[i, 0, :]
        return fused_embeddings

    def process_image(self, image):
        """Process an image through CLIP to get embeddings"""
        image_inputs = self.clip_processor(images=image, return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            # Returns a [batch, projection_dim] image embedding
            image_embedding = self.clip.get_image_features(**image_inputs)
        return image_embedding

    def generate_description(self, image, prompt_template="[IMG] A detailed description of this image is:", max_tokens=100):
        """End-to-end generation from image to text description"""
        # Process image
        if isinstance(image, str):
            image = Image.open(image).convert("RGB")
        elif not isinstance(image, Image.Image):
            image = Image.fromarray(image).convert("RGB")

        # Process text prompt
        tokenized = self.tokenizer(prompt_template, return_tensors="pt", truncation=True, max_length=128)
        text_input_ids = tokenized["input_ids"].to(DEVICE)
        attention_mask = tokenized["attention_mask"].to(DEVICE)

        # Get image embedding
        image_embedding = self.process_image(image)

        # Generate description
        with torch.no_grad():
            fused_embeddings = self(
                text_input_ids=text_input_ids,
                attention_mask=attention_mask,
                image_embedding=image_embedding
            )
            generated_ids = self.phi.generate(
                inputs_embeds=fused_embeddings,
                attention_mask=attention_mask,
                max_new_tokens=max_tokens,
                do_sample=False,  # Greedy decoding for deterministic output
                repetition_penalty=1.2
            )
        output = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        return output

    def load_weights(self, weights_path):
        """Load saved weights for the image projection layer"""
        try:
            state_dict = torch.load(weights_path, map_location=DEVICE)
            self.image_projection.load_state_dict(state_dict)
            return True
        except Exception as e:
            print(f"Failed to load weights: {e}")
            return False
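
    # The weights file is expected to be a plain state_dict of self.image_projection,
    # e.g. produced on the training side by something like (hypothetical snippet):
    #
    #     torch.save(model.image_projection.state_dict(), "image_projection.pt")
    #
    # which load_weights() then restores via load_state_dict() above.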


# Global model instance (will be loaded on demand)
model = None


def load_model():
    """Load the model if not already loaded"""
    global model
    if model is None:
        print("Loading models. This may take a few minutes...")
        model = MultiModalModel().to(DEVICE)
        print("Models loaded!")
    return model

def generate_description(image, prompt, max_length):
    """Generate a description for the given image"""
    try:
        model = load_model()
        result = model.generate_description(image, prompt, int(max_length))
        return result
    except Exception as e:
        return f"Error generating description: {str(e)}"

def load_projection_weights(weights_file):
    """Load custom projection weights"""
    try:
        model = load_model()
        # gr.File may return a filepath string or a tempfile-like object, depending on the Gradio version
        weights_path = weights_file if isinstance(weights_file, str) else weights_file.name
        success = model.load_weights(weights_path)
        if success:
            return "✅ Projection weights loaded successfully!"
        else:
            return "❌ Failed to load weights"
    except Exception as e:
        return f"❌ Error: {str(e)}"

def create_interface():
    """Create and return the Gradio interface"""
    with gr.Blocks() as demo:
        gr.Markdown("# Multimodal Image Description with Phi-3 Mini")

        with gr.Tab("Generate"):
            with gr.Row():
                with gr.Column():
                    image_input = gr.Image(label="Upload Image", type="pil")
                    prompt_input = gr.Textbox(
                        label="Prompt (use [IMG] for image placement)",
                        value="[IMG] A detailed description of this image is:",
                        lines=2
                    )
                    max_length = gr.Slider(
                        minimum=10, maximum=300, value=100, step=10,
                        label="Maximum Output Length"
                    )
                    submit_btn = gr.Button("Generate Description")
                with gr.Column():
                    output_text = gr.Textbox(label="Generated Description", lines=12)
            submit_btn.click(
                generate_description,
                inputs=[image_input, prompt_input, max_length],
                outputs=output_text
            )

        with gr.Tab("Advanced"):
            gr.Markdown("### Load Custom Projection Weights")
            weights_file = gr.File(label="Upload Projection Weights (.pt file)")
            load_btn = gr.Button("Load Weights")
            weight_status = gr.Textbox(label="Status")
            load_btn.click(
                load_projection_weights,
                inputs=[weights_file],
                outputs=weight_status
            )

        gr.Markdown("""
        ### About This Model

        This app uses:
        - CLIP (ViT-B/32) to extract image features
        - Phi-3 Mini for text generation
        - A projection layer to connect image and text spaces

        For optimal performance, upload projection weights trained for this specific setup.
        """)

    return demo

# Optional: For testing directly from this file
if __name__ == "__main__":
    demo = create_interface()
    demo.queue()
    demo.launch()
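
# Programmatic use without the Gradio UI (a minimal sketch; "example.jpg" is a
# hypothetical local path):
#
#     model = load_model()
#     print(model.generate_description("example.jpg", max_tokens=80))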