import spaces
from install_flsh_attn import attn_implementation, dtype
from transformers import (
    AutoImageProcessor,
    AutoModelForCausalLM,
)
import gradio as gr
import torch
from accelerate import Accelerator
import numpy as np
import cv2
from PIL import Image
import zipfile
import io
import json

DEVICE = Accelerator().device
MODEL_NAME = "qihoo360/fg-clip2-so400m"
BATCH_SIZE = 128

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    dtype=dtype,
    attn_implementation=attn_implementation,
).to(DEVICE)
image_processor = AutoImageProcessor.from_pretrained(MODEL_NAME)

def determine_max_value(image):
    """Determine max_num_patches based on image size."""
    w, h = image.size
    max_val = (w // 16) * (h // 16)
    if max_val > 784:
        return 1024
    elif max_val > 576:
        return 784
    elif max_val > 256:
        return 576
    elif max_val > 128:
        return 256
    else:
        return 128

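# For example (illustrative): a 640x480 image has (640 // 16) * (480 // 16)
# = 40 * 30 = 1200 raw patches, which lands in the "> 784" bucket, so the
# processor is capped at 1024 patches for that image.
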
@spaces.GPU  # assumed ZeroGPU setup (otherwise `import spaces` is unused); a no-op on other hardware
def generate_image_embeddings(zip_file):
    """
    Generate embeddings from images in a zip file.

    Args:
        zip_file: Uploaded zip file containing images

    Returns:
        Tuple of (embeddings as a JSON string, status message)
    """
    try:
        # Extract images from the zip
        images = []
        print(f"Extracting images from zip file: {zip_file.name}")
        with zipfile.ZipFile(zip_file.name, "r") as zip_ref:
            for file_info in zip_ref.infolist():
                # Skip directories and macOS resource-fork entries
                if file_info.is_dir() or file_info.filename.startswith("__MACOSX"):
                    continue
                if file_info.filename.lower().endswith(
                    (".png", ".jpg", ".jpeg", ".bmp", ".webp")
                ):
                    with zip_ref.open(file_info) as img_file:
                        img = Image.open(io.BytesIO(img_file.read())).convert("RGB")
                        images.append(img)
        print(f"Extracted {len(images)} images from zip file")

        if len(images) == 0:
            return None, "❌ No valid images found in the zip file"

        # Generate embeddings with batching
        embeddings = []
        print(f"Generating embeddings for {len(images)} images...")
        with torch.no_grad():
            for i in range(0, len(images), BATCH_SIZE):
                batch = images[i : i + BATCH_SIZE]
                print(
                    f"Processing batch {i // BATCH_SIZE + 1}/{(len(images) + BATCH_SIZE - 1) // BATCH_SIZE} ({len(batch)} images)"
                )
                # Use the same max_num_patches for all images in the batch
                max_patches = max(determine_max_value(img) for img in batch)
                image_input = image_processor(
                    images=batch,
                    max_num_patches=max_patches,
                    return_tensors="pt",
                ).to(DEVICE)
                image_features = model.get_image_features(**image_input)
                # L2-normalize the embeddings
                normalized_features = image_features / image_features.norm(
                    dim=-1, keepdim=True
                )
                embeddings.append(
                    normalized_features.to(dtype=torch.float32).cpu().numpy()
                )
        embeddings = np.vstack(embeddings)
        print(f"Embeddings shape: {embeddings.shape}")

        # Create JSON output
        result = json.dumps(
            {
                "embeddings": embeddings.tolist(),
                "shape": list(embeddings.shape),
                "count": len(images),
            },
            indent=2,
        )
        message = f"✅ Successfully generated embeddings for {len(images)} images\nShape: {embeddings.shape}"
        print(message)
        return result, message
    except Exception as e:
        error_msg = f"❌ Error: {str(e)}"
        print(error_msg)
        return None, error_msg

def extract_frames(video_path: str, fps: int = 4):
    """
    Extract frames from a video at the specified sampling rate.

    Args:
        video_path: Path to the video file
        fps: Frames per second to sample

    Returns:
        List of PIL Images
    """
    cap = cv2.VideoCapture(video_path)
    video_fps = cap.get(cv2.CAP_PROP_FPS)
    # Round the ratio (not just the native fps), and guard against an
    # interval of 0 when the native fps is at or below the sampling rate.
    frame_interval = max(1, int(round(video_fps / fps)))
    frames = []
    frame_count = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        if frame_count % frame_interval == 0:
            # Convert BGR (OpenCV) to RGB (PIL)
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(frame_rgb)
            frames.append(pil_image)
        frame_count += 1
    cap.release()
    return frames

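# For example: a 30 fps video sampled with fps=4 gives frame_interval =
# round(30 / 4) = 8, so one frame in every 8 is kept, for an effective rate
# of 3.75 frames per second.
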
@spaces.GPU  # assumed ZeroGPU setup; a no-op on other hardware
def generate_video_embeddings(video_path, fps):
    """
    Generate embeddings from video frames.

    Args:
        video_path: Path to video file (str)
        fps: Frames per second to extract

    Returns:
        Tuple of (embeddings as a JSON string, status message)
    """
    try:
        # Extract frames
        print(f"Extracting frames from video: {video_path} at {fps} fps")
        frames = extract_frames(video_path, fps)
        print(f"Extracted {len(frames)} frames from video")

        if len(frames) == 0:
            return None, "❌ No frames could be extracted from the video"

        # Generate embeddings with batching
        embeddings = []
        print(f"Generating embeddings for {len(frames)} frames...")
        with torch.no_grad():
            for i in range(0, len(frames), BATCH_SIZE):
                batch = frames[i : i + BATCH_SIZE]
                print(
                    f"Processing batch {i // BATCH_SIZE + 1}/{(len(frames) + BATCH_SIZE - 1) // BATCH_SIZE} ({len(batch)} frames)"
                )
                # Use the same max_num_patches for all frames in the batch
                max_patches = max(determine_max_value(frame) for frame in batch)
                image_input = image_processor(
                    images=batch,
                    max_num_patches=max_patches,
                    return_tensors="pt",
                ).to(DEVICE)
                image_features = model.get_image_features(**image_input)
                # L2-normalize the embeddings
                normalized_features = image_features / image_features.norm(
                    dim=-1, keepdim=True
                )
                embeddings.append(
                    normalized_features.to(dtype=torch.float32).cpu().numpy()
                )
        embeddings = np.vstack(embeddings)
        print(f"Embeddings shape: {embeddings.shape}")

        # Create JSON output
        result = json.dumps(
            {
                "embeddings": embeddings.tolist(),
                "shape": list(embeddings.shape),
                "count": len(frames),
                "fps": fps,
            },
            indent=2,
        )
        message = f"✅ Successfully generated embeddings for {len(frames)} frames (extracted at {fps} fps)\nShape: {embeddings.shape}"
        print(message)
        return result, message
    except Exception as e:
        error_msg = f"❌ Error: {str(e)}"
        print(error_msg)
        return None, error_msg

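# Because the embeddings are L2-normalized, cosine similarity between any two
# of them reduces to a dot product. Illustrative use of the decoded matrix:
#
#     sims = embeddings @ embeddings.T  # (n, n) pairwise similarities
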
# Create Gradio interface
with gr.Blocks(title="Video & Image Embedding Generator") as demo:
    gr.Markdown("# 🎬 Video & Image Embedding Generator")
    gr.Markdown(f"Generate embeddings using the **{MODEL_NAME}** model")

    with gr.Tab("📦 Images from ZIP"):
        gr.Markdown("Upload a ZIP file containing images to generate embeddings")
        with gr.Row():
            with gr.Column():
                zip_input = gr.File(label="Upload ZIP file", file_types=[".zip"])
                img_submit_btn = gr.Button("Generate Embeddings", variant="primary")
            with gr.Column():
                img_output = gr.JSON(label="Embeddings (JSON)")
                img_status = gr.Textbox(label="Status", lines=3)
        img_submit_btn.click(
            fn=generate_image_embeddings,
            inputs=[zip_input],
            outputs=[img_output, img_status],
        )

    with gr.Tab("🎥 Video Frames"):
        gr.Markdown(
            "Upload a video and specify FPS to extract frames and generate embeddings"
        )
        with gr.Row():
            with gr.Column():
                video_input = gr.Video(label="Upload Video")
                fps_input = gr.Slider(
                    minimum=1,
                    maximum=30,
                    value=4,
                    step=1,
                    label="Frames per Second (FPS)",
                )
                vid_submit_btn = gr.Button("Generate Embeddings", variant="primary")
            with gr.Column():
                vid_output = gr.JSON(label="Embeddings (JSON)")
                vid_status = gr.Textbox(label="Status", lines=3)

        def handle_video_upload(video_file, fps):
            if video_file is None:
                return None, "❌ Please upload a video file"
            return generate_video_embeddings(
                video_file.name if hasattr(video_file, "name") else video_file, fps
            )

        vid_submit_btn.click(
            fn=handle_video_upload,
            inputs=[video_input, fps_input],
            outputs=[vid_output, vid_status],
        )

    gr.Markdown(
        """
        ### 📝 Notes:
        - Images in ZIP: Supports PNG, JPG, JPEG, BMP, WEBP formats
        - Video: Supports common video formats (MP4, AVI, MOV, etc.)
        - Output: JSON object containing L2-normalized embeddings with metadata
        - Structure: `{"embeddings": [...], "shape": [n, dim], "count": n, "fps": f}`
        """
    )

if __name__ == "__main__":
    demo.launch()
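
# Minimal client-side sketch (illustrative, not part of the app): once the
# Space is running, its endpoints can be called with gradio_client. The
# endpoint name "/handle_video_upload" and the "<owner>/<space-name>" id are
# assumptions; check Client.view_api() for the actual names.
#
#     from gradio_client import Client, handle_file
#     import json
#     import numpy as np
#
#     client = Client("<owner>/<space-name>")  # hypothetical Space id
#     result, status = client.predict(
#         handle_file("clip.mp4"), 4, api_name="/handle_video_upload"
#     )
#     data = json.loads(result) if isinstance(result, str) else result
#     embeddings = np.asarray(data["embeddings"], dtype=np.float32)
#     print(status, embeddings.shape)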