import os
import torch
import gradio as gr
from PIL import Image
from transformers import CLIPProcessor, CLIPModel, AutoTokenizer, AutoModelForCausalLM
from peft import PeftConfig, PeftModel
import torch.nn as nn
import torch.nn.functional as F
import gc
import logging

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.info(f"Gradio version: {gr.__version__}")

# Device setup
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Only default to float16 on GPU; many CPU ops do not support float16, so a
# global float16 default can fail at runtime on CPU-only Spaces.
if DEVICE.type == "cuda":
    torch.set_default_dtype(torch.float16)
logger.info(f"Using device: {DEVICE}")

class MultiModalModel(nn.Module):
    def __init__(self, phi_model_name="microsoft/phi-3-mini-4k-instruct",
                 clip_model_name="openai/clip-vit-base-patch32", peft_model_path=None):
        super().__init__()
        logger.info("Loading CLIP model...")
        self.clip = CLIPModel.from_pretrained(clip_model_name, torch_dtype=torch.float16).to(DEVICE)
        self.clip_processor = CLIPProcessor.from_pretrained(clip_model_name, use_fast=True)

        logger.info("Loading language model...")
        if peft_model_path:
            logger.info(f"Loading PEFT model from {peft_model_path}")
            try:
                config = PeftConfig.from_pretrained(peft_model_path)
                base_model = AutoModelForCausalLM.from_pretrained(
                    config.base_model_name_or_path,
                    torch_dtype=torch.float16,
                    low_cpu_mem_usage=True,
                    device_map=DEVICE
                )
                self.phi = PeftModel.from_pretrained(base_model, peft_model_path)
                self.tokenizer = AutoTokenizer.from_pretrained(peft_model_path)
            except Exception as e:
                logger.error(f"Failed to load PEFT model: {str(e)}", exc_info=True)
                raise
        else:
            logger.info(f"Loading base model {phi_model_name}")
            self.phi = AutoModelForCausalLM.from_pretrained(
                phi_model_name,
                torch_dtype=torch.float16,
                low_cpu_mem_usage=True,
                device_map=DEVICE
            )
            self.tokenizer = AutoTokenizer.from_pretrained(phi_model_name)

        # Register the [IMG] placeholder token (and a pad token), then resize the
        # embedding matrix so the new ids have embeddings.
        self.tokenizer.add_special_tokens({"additional_special_tokens": ["[IMG]"], "pad_token": "<pad>"})
        self.phi.resize_token_embeddings(len(self.tokenizer))

        # Project the CLIP image embedding (projection_dim) up to the Phi hidden
        # size so it can replace the [IMG] token embedding in forward().
        image_embedding_dim = self.clip.config.projection_dim
        phi_hidden_size = self.phi.config.hidden_size
        self.image_projection = nn.Sequential(
            nn.Linear(image_embedding_dim, image_embedding_dim * 2),
            nn.GELU(),
            nn.Linear(image_embedding_dim * 2, phi_hidden_size),
            nn.LayerNorm(phi_hidden_size),
            nn.Dropout(0.1)
        ).to(DEVICE)

    def forward(self, text_input_ids, attention_mask=None, image_embedding=None):
        image_embedding = F.normalize(image_embedding, dim=-1)
        projected_image = 10.0 * self.image_projection(image_embedding)
        if projected_image.dim() == 2:
            projected_image = projected_image.unsqueeze(1)
        text_embeddings = self.phi.get_input_embeddings()(text_input_ids)
        # Swap the embedding of the first [IMG] token in each sequence for the
        # projected image embedding.
        img_token_id = self.tokenizer.convert_tokens_to_ids("[IMG]")
        img_token_mask = (text_input_ids == img_token_id)
        fused_embeddings = text_embeddings.clone()
        for i in range(fused_embeddings.shape[0]):
            img_positions = img_token_mask[i].nonzero(as_tuple=True)[0]
            if img_positions.numel() > 0:
                fused_embeddings[i, img_positions[0], :] = projected_image[i, 0, :]
        return fused_embeddings

    def process_image(self, image):
        image_inputs = self.clip_processor(images=image, return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            image_embedding = self.clip.get_image_features(**image_inputs)
        return image_embedding

    def generate_description(self, image, prompt_template="[IMG] A detailed description of this image is:", max_tokens=100):
        if isinstance(image, str):
            image = Image.open(image).convert("RGB")
        elif not isinstance(image, Image.Image):
            image = Image.fromarray(image).convert("RGB")
        image = image.resize((224, 224), Image.LANCZOS)
        tokenized = self.tokenizer(prompt_template, return_tensors="pt", truncation=True, max_length=128)
        text_input_ids = tokenized["input_ids"].to(DEVICE)
        attention_mask = tokenized["attention_mask"].to(DEVICE)
        image_embedding = self.process_image(image)
        with torch.no_grad():
            fused_embeddings = self(
                text_input_ids=text_input_ids,
                attention_mask=attention_mask,
                image_embedding=image_embedding
            )
            generated_ids = self.phi.generate(
                inputs_embeds=fused_embeddings,
                attention_mask=attention_mask,
                max_new_tokens=max_tokens,
                do_sample=False,
                repetition_penalty=1.2
            )
        output = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        return output
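
# Hypothetical standalone usage (not part of the Space's request flow), assuming a
# local image "example.jpg" and no PEFT adapter:
#   mm = MultiModalModel()
#   print(mm.generate_description("example.jpg", max_tokens=80))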

model = None

def load_model(peft_model_path=None):
    global model
    if model is None:
        logger.info("Loading model...")
        try:
            model = MultiModalModel(peft_model_path=peft_model_path)
            logger.info("Model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load model: {str(e)}", exc_info=True)
            raise
        gc.collect()
        if DEVICE.type == "cuda":
            torch.cuda.empty_cache()
    return model

def generate_description(image, prompt, max_length):
    logger.info("Generating description...")
    try:
        model = load_model(peft_model_path=os.getenv("model_V1", None))
        if image is None:
            logger.error("No image provided")
            return "Error: No image provided"
        result = model.generate_description(image, prompt, int(max_length))
        logger.info("Description generated successfully")
        gc.collect()
        return result
    except Exception as e:
        logger.error(f"Error generating description: {str(e)}", exc_info=True)
        return f"Error: {str(e)}"

# Gradio interface
# COLOR_THEME is referenced as a default below but never defined in this file;
# fall back to "blue" here (assumption) so the module can run standalone.
COLOR_THEME = os.getenv("COLOR_THEME", "blue")

def create_gradio_interface(generate_fn, color_theme=COLOR_THEME):
    # Set color variables based on theme
    if color_theme == "blue":
        primary_gradient = "linear-gradient(145deg, #e0f2fe, #dbeafe)"
        header_gradient = "linear-gradient(135deg, #bfdbfe, #93c5fd)"
        header_background = "#dbeafe"  # Light blue for section headers
        button_gradient = "linear-gradient(135deg, #3b82f6, #1d4ed8)"
        button_hover_gradient = "linear-gradient(135deg, #2563eb, #1e40af)"
        primary_color = "#1e40af"
        icon_color = "#2563eb"
        shadow_color = "rgba(59, 130, 246, 0.15)"
        button_shadow = "rgba(29, 78, 216, 0.25)"
    else:
        primary_gradient = "linear-gradient(145deg, #fff7ed, #ffedd5)"
        header_gradient = "linear-gradient(135deg, #fed7aa, #fdba74)"
        header_background = "#ffedd5"  # Light orange for section headers
        button_gradient = "linear-gradient(135deg, #f97316, #ea580c)"
        button_hover_gradient = "linear-gradient(135deg, #ea580c, #c2410c)"
        primary_color = "#9a3412"
        icon_color = "#ea580c"
        shadow_color = "rgba(249, 115, 22, 0.15)"
        button_shadow = "rgba(234, 88, 12, 0.25)"

    # Custom CSS with dynamic color variables
    custom_css = f"""
    body {{
        font-family: 'Inter', 'Segoe UI', sans-serif;
        background-color: #f8fafc;
    }}
    .container {{
        background: {primary_gradient};
        border-radius: 16px;
        padding: 30px;
        max-width: 1200px;
        margin: 0 auto;
        box-shadow: 0 10px 25px {shadow_color};
    }}
    .app-header {{
        text-align: center;
        margin-bottom: 30px;
        background: {header_gradient};
        border-radius: 12px;
        padding: 20px;
        box-shadow: 0 4px 12px {shadow_color};
    }}
    .app-title {{
        color: {primary_color};
        font-size: 2.2em;
        font-weight: 700;
        margin-bottom: 10px;
    }}
    .app-description {{
        color: #334155;
        font-size: 1.1em;
        line-height: 1.5;
        max-width: 700px;
        margin: 0 auto;
    }}
    .card {{
        background: #ffffff;
        border-radius: 12px;
        padding: 20px;
        margin-bottom: 20px;
        box-shadow: 0 4px 12px rgba(0,0,0,0.05);
        border: 1px solid rgba(226, 232, 240, 0.8);
        transition: transform 0.2s, box-shadow 0.2s;
        height: 100%;
    }}
    .card:hover {{
        transform: translateY(-2px);
        box-shadow: 0 6px 16px rgba(0,0,0,0.08);
    }}
    .input-label {{
        color: {primary_color};
        font-weight: 600;
        margin-bottom: 8px;
        font-size: 1.05em;
        background: {header_background}; /* Add background to section headers */
        padding: 5px 10px;
        border-radius: 6px;
        display: inline-block;
    }}
    .output-card {{
        background: #ffffff;
        border-radius: 12px;
        padding: 25px;
        border: 1px solid rgba(226, 232, 240, 0.8);
        box-shadow: 0 4px 15px rgba(0,0,0,0.05);
        height: 100%;
        display: flex;
        flex-direction: column;
    }}
    .output-content {{
        font-size: 1.1em;
        line-height: 1.6;
        color: #1e293b;
        flex-grow: 1;
    }}
    .btn-generate {{
        background: {button_gradient} !important;
        color: white !important;
        border-radius: 8px !important;
        padding: 12px 24px !important;
        font-weight: 600 !important;
        font-size: 1.05em !important;
        border: none !important;
        box-shadow: 0 4px 12px {button_shadow} !important;
        transition: all 0.3s ease !important;
        width: 100% !important;
        margin-top: 15px;
    }}
    .btn-generate:hover {{
        background: {button_hover_gradient} !important;
        box-shadow: 0 6px 16px {button_shadow} !important;
        transform: translateY(-2px) !important;
    }}
    .footer {{
        text-align: center;
        margin-top: 30px;
        color: #64748b;
        font-size: 0.9em;
    }}
    .model-selector {{
        margin-bottom: 15px;
    }}
    .input-icon {{
        font-size: 1.5em;
        margin-right: 8px;
        color: {icon_color};
    }}
    .divider {{
        border-top: 1px solid #e2e8f0;
        margin: 15px 0;
    }}
    .input-section {{
        height: 100%;
    }}
    .result-heading {{
        margin-bottom: 15px;
        color: {primary_color};
        background: {header_background}; /* Add background to result header */
        padding: 5px 10px;
        border-radius: 6px;
        display: inline-block;
    }}
    """

    # Create Blocks interface with improved structure and parallel layout
    with gr.Blocks(css=custom_css) as iface:
        with gr.Group():
            icon = "🔷" if color_theme == "blue" else "🔶"
            app_name = "OmniPhi Blue" if color_theme == "blue" else "OmniPhi Orange"
            gr.Markdown(
                f"""
                <div class="app-header">
                    <div class="app-title">{icon} {app_name}</div>
                    <div class="app-description">Advanced Multi-Modal AI with BLIP or Custom Model Integration. Upload an image and provide instructions through text or voice to generate detailed descriptions.</div>
                </div>
                """
            )

        # Main content in a 2-column layout (inputs and output side by side)
        with gr.Row():
            # Left column for all inputs
            with gr.Column(scale=3):
                with gr.Group():
                    # Image upload card
                    with gr.Group():
                        gr.Markdown('<span class="input-icon">🖼️</span><span class="input-label">Upload Image</span>')
                        image_input = gr.Image(
                            type="pil",
                            label=None
                        )

                    # Text and voice input card
                    with gr.Group():
                        gr.Markdown('<span class="input-icon">💬</span><span class="input-label">Text Instruction</span>')
                        text_input = gr.Textbox(
                            label=None,
                            placeholder="e.g., Describe this image in detail, focusing on the environment...",
                            lines=3
                        )
                        gr.Markdown('<div class="divider"></div>')
                        gr.Markdown('<span class="input-icon">🎙️</span><span class="input-label">Voice Instruction (optional)</span>')
                        audio_input = gr.Audio(
                            # "microphone" is not a valid `type` for gr.Audio; in Gradio 4.x the
                            # microphone is selected via `sources`, while `type` controls the value
                            # passed to the callback ("filepath" assumed here).
                            sources=["microphone"],
                            type="filepath",
                            label=None
                        )
                        gr.Markdown('<div class="divider"></div>')
                        gr.Markdown('<span class="input-icon">⚙️</span><span class="input-label">Model Selection</span>')
                        with gr.Group():
                            model_choice = gr.Radio(
                                choices=["BLIP", "OmniPhi"],
                                value="BLIP",
                                label=None,
                                interactive=True
                            )

                    submit_btn = gr.Button("Generate Description")

            # Right column for output
            with gr.Column(scale=2):
                with gr.Group():
                    gr.Markdown('<span class="input-icon">✨</span><span class="input-label result-heading">Generated Description</span>')
                    output = gr.Textbox(
                        label=None,
                        lines=12,
                        placeholder="Your description will appear here after generation..."
                    )

        # Footer
        gr.Markdown(
            f"""
            <div class="footer">
                Powered by OmniPhi Technology • Upload your image and provide instructions through text or voice
            </div>
            """
        )

        # Connect the button to the function
        submit_btn.click(
            fn=generate_fn,
            inputs=[image_input, text_input, audio_input, model_choice],
            outputs=output
        )

    return iface

# Main execution
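# NOTE (assumption): the constants referenced below (WHISPER_MODEL, BLIP_MODEL,
# CLIP_MODEL, PHI_MODEL, CHECKPOINT_DIR, TORCH_DTYPE) and the loader helpers
# (initialize_transcriber, load_blip, load_clip, load_omniphi), as well as the
# multi-model generate_description they feed, are not defined in this file and
# are presumed to live in a companion module of this Space.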
if __name__ == "__main__":
    # Load models
    transcriber = initialize_transcriber(WHISPER_MODEL)
    blip_model, blip_processor = load_blip(BLIP_MODEL, DEVICE, TORCH_DTYPE)
    clip_model, clip_processor = load_clip(CLIP_MODEL, DEVICE, TORCH_DTYPE)
    omniphi_model, omniphi_tokenizer = load_omniphi(CHECKPOINT_DIR, PHI_MODEL, CLIP_MODEL, DEVICE)

    # Define generate function
    generate_fn = lambda image, text_prompt, audio, model_choice: generate_description(
        image, text_prompt, audio, model_choice, transcriber, blip_model, blip_processor,
        clip_model, clip_processor, omniphi_model, omniphi_tokenizer, DEVICE
    )

    # Launch Gradio interface
    iface = create_gradio_interface(generate_fn, color_theme=COLOR_THEME)
    iface.launch(server_name="0.0.0.0", server_port=7860)