import os
from pathlib import Path
from threading import Thread

import cv2
import gradio as gr
import spaces
import torch
from huggingface_hub import login
from PIL import Image
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer

login(token=os.getenv("HUGGING_FACE_HUB_TOKEN"))

# Model config
model_4b_name = "google/gemma-3-4b-it"
model_4b = Gemma3ForConditionalGeneration.from_pretrained(
    model_4b_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
).eval()
processor_4b = AutoProcessor.from_pretrained(model_4b_name)


# TODO: add frame timestamps later
def extract_video_frames(video_path, num_frames=8):
    """Sample `num_frames` evenly spaced frames from a video as PIL images."""
    cap = cv2.VideoCapture(video_path)
    frames = []
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    step = max(total_frames // num_frames, 1)
    for i in range(num_frames):
        cap.set(cv2.CAP_PROP_POS_FRAMES, i * step)
        ret, frame = cap.read()
        if ret:
            # OpenCV decodes to BGR; convert to RGB before handing to PIL.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(Image.fromarray(frame))
    cap.release()
    return frames


def format_message(content, files):
    """Build the multimodal content list for a single user turn.

    Each inline "<image>" placeholder in the text is replaced by the next
    uploaded file in order; any remaining files are appended at the end.
    """
    message_content = []
    if content:
        parts = content.split("<image>")
        for i, part in enumerate(parts):
            if part.strip():
                message_content.append({"type": "text", "text": part.strip()})
            if i < len(parts) - 1 and files:
                img = Image.open(files.pop(0))
                message_content.append({"type": "image", "image": img})
    for file in files:
        file_path = file if isinstance(file, str) else file.name
        suffix = Path(file_path).suffix.lower()
        if suffix in ['.jpg', '.jpeg', '.png']:
            img = Image.open(file_path)
            message_content.append({"type": "image", "image": img})
        elif suffix in ['.mp4', '.mov']:
            # Videos are passed to the model as a sequence of sampled frames.
            frames = extract_video_frames(file_path)
            for frame in frames:
                message_content.append({"type": "image", "image": frame})
    return message_content


def format_conversation_history(chat_history):
    """Convert Gradio's messages-format history into the chat-template format,
    merging consecutive user items into one multimodal user turn."""
    messages = []
    current_user_content = []
    for item in chat_history:
        role = item["role"]
        content = item["content"]
        if role == "user":
            if isinstance(content, str):
                current_user_content.append({"type": "text", "text": content})
            elif isinstance(content, list):
                current_user_content.extend(content)
            else:
                current_user_content.append({"type": "text", "text": str(content)})
        elif role == "assistant":
            if current_user_content:
                messages.append({"role": "user", "content": current_user_content})
                current_user_content = []
            messages.append({"role": "assistant", "content": [{"type": "text", "text": str(content)}]})
    if current_user_content:
        messages.append({"role": "user", "content": current_user_content})
    return messages


@spaces.GPU(duration=120)
def generate_response(input_data, chat_history, max_new_tokens, system_prompt, temperature, top_p, top_k, repetition_penalty):
    # MultimodalTextbox submits a dict with "text" and "files" keys.
    if isinstance(input_data, dict) and "text" in input_data:
        text = input_data["text"]
        files = input_data.get("files", [])
    else:
        text = str(input_data)
        files = []

    new_message_content = format_message(text, files)
    new_message = {"role": "user", "content": new_message_content}

    system_message = (
        [{"role": "system", "content": [{"type": "text", "text": system_prompt}]}]
        if system_prompt
        else []
    )
    processed_history = format_conversation_history(chat_history)
    messages = system_message + processed_history
    if messages and messages[-1]["role"] == "user":
        # Fold the new content into the trailing user turn rather than
        # creating two consecutive user messages.
        messages[-1]["content"].extend(new_message["content"])
    else:
        messages.append(new_message)

    model = model_4b
    processor = processor_4b
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)

    # Use the tokenizer for streaming decode
    streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
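    # Run generate() on a background thread so this generator can yield
    # partial text to the Gradio UI as the streamer produces tokens.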
    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    outputs = []
    for new_text in streamer:
        outputs.append(new_text)
        yield "".join(outputs)


demo = gr.ChatInterface(
    fn=generate_response,
    additional_inputs=[
        gr.Slider(label="Max new tokens", minimum=100, maximum=2000, step=1, value=512),
        gr.Textbox(
            label="System Prompt",
            value="You are a friendly chatbot.",
            lines=4,
            placeholder="Change system prompt",
        ),
        gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.7),
        gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
        gr.Slider(label="Top-k", minimum=1, maximum=100, step=1, value=50),
        gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.0),
    ],
    examples=[
        # "Read the text on the sign and translate it into English."
        [{"files": ["images/Stop_sign_UAE.jpg"], "text": "اقرأ النص على اللافتة وترجمه إلى الإنجليزية."}],
        # "Explain the cultural significance of this rangoli and the diyas in one or two sentences."
        [{"files": ["images/The_Rangoli_of_Lights.jpg"], "text": "इस रंगोली और दीयों का सांस्कृतिक महत्व एक–दो वाक्यों में समझाइए।"}],
        # "What is this vessel called, and what is its cultural symbolism in the Arabian Peninsula?"
        [{"files": ["images/A_dallah_a_traditional_Arabic_coffee_pot_with_cups_and_coffee_beans.jpg"], "text": "ما اسم هذا الإناء وما رمزيته الثقافية في الجزيرة العربية؟"}],
        # "Which festival is this? Explain in two sentences why people apply colors."
        [{"files": ["images/Indian_Festival_of_colors_Holi.jpg"], "text": "यह कौन‑सा त्योहार है? दो वाक्यों में बताइए कि लोग रंग क्यों लगाते हैं।"}],
    ],
    cache_examples=False,
    type="messages",
    description="""
# Finetuned Model
""",
    fill_height=True,
    textbox=gr.MultimodalTextbox(
        label="Query Input",
        file_types=["image", "video"],
        file_count="multiple",
        placeholder="Type your message or upload media",
    ),
    stop_btn="Stop Generation",
    multimodal=True,
    theme=gr.themes.Soft(),
)

if __name__ == "__main__":
    demo.launch()