# tmp-service / app.py
# JacobLinCool's picture
# Update app.py
# e7b5ffc verified
import spaces
from install_flsh_attn import attn_implementation, dtype
from transformers import (
AutoImageProcessor,
AutoModelForCausalLM,
)
import gradio as gr
import torch
from accelerate import Accelerator
import numpy as np
import cv2
from PIL import Image
import zipfile
import io
import json
# Device selected by Accelerate; the model and every input batch are moved to it.
DEVICE = Accelerator().device
# Checkpoint identifier handed to both `from_pretrained` calls below.
MODEL_NAME = "qihoo360/fg-clip2-so400m"
# Upper bound on how many images go through the model in a single forward pass.
BATCH_SIZE = 128
# `dtype` and `attn_implementation` are provided by the local `install_flsh_attn` module.
# NOTE(review): trust_remote_code=True presumably lets the checkpoint repository
# supply its own modelling code -- confirm the repository is trusted/pinned.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    dtype=dtype,
    attn_implementation=attn_implementation,
).to(DEVICE)
# Preprocessor that turns PIL images into the tensors passed to `model.get_image_features`.
image_processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
def determine_max_value(image):
    """Pick the ``max_num_patches`` bucket for an image.

    The image is measured in whole 16x16-pixel tiles,
    ``(width // 16) * (height // 16)``, and that tile count is mapped onto
    one of the fixed buckets 128, 256, 576, 784 or 1024.

    Args:
        image: Object exposing ``size`` as a ``(width, height)`` pair, such
            as a PIL image.

    Returns:
        The bucket value to use as ``max_num_patches``.
    """
    width, height = image.size
    tile_count = (width // 16) * (height // 16)
    # (exclusive lower bound on the tile count, bucket), largest bucket first.
    buckets = ((784, 1024), (576, 784), (256, 576), (128, 256))
    for lower_bound, bucket in buckets:
        if tile_count > lower_bound:
            return bucket
    return 128
@spaces.GPU
def generate_image_embeddings(zip_file):
    """Embed every image found in an uploaded ZIP archive.

    Args:
        zip_file: The uploaded archive, given either as a filesystem path
            string or as an object whose ``name`` attribute holds that path.

    Returns:
        Tuple ``(result, message)``. On success ``result`` is a JSON string
        with the keys ``embeddings`` (one L2-normalized vector per image, in
        archive order), ``shape`` and ``count``. On failure ``result`` is
        ``None``. ``message`` is a human-readable status text in both cases.
    """
    if zip_file is None:
        return None, "❌ Please upload a ZIP file"
    try:
        # Accept both a plain path and a file wrapper exposing ``.name``.
        zip_path = zip_file if isinstance(zip_file, str) else zip_file.name
        image_suffixes = (".png", ".jpg", ".jpeg", ".bmp", ".webp")

        # Extract images from zip
        images = []
        print(f"Extracting images from zip file: {zip_path}")
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            for file_info in zip_ref.infolist():
                entry_name = file_info.filename
                if not entry_name.lower().endswith(image_suffixes):
                    continue
                # Archives created on macOS carry "__MACOSX/._<name>" metadata
                # entries that share the image's extension but are not images;
                # decoding one would fail the whole request.
                base_name = entry_name.rsplit("/", 1)[-1]
                if entry_name.startswith("__MACOSX/") or base_name.startswith("._"):
                    continue
                with zip_ref.open(file_info) as img_file:
                    img = Image.open(io.BytesIO(img_file.read())).convert("RGB")
                images.append(img)
        print(f"Extracted {len(images)} images from zip file")
        if not images:
            return None, "❌ No valid images found in the zip file"

        # Generate embeddings with batching
        embeddings = []
        total_batches = (len(images) + BATCH_SIZE - 1) // BATCH_SIZE
        print(f"Generating embeddings for {len(images)} images...")
        with torch.no_grad():
            for batch_number, start in enumerate(
                range(0, len(images), BATCH_SIZE), start=1
            ):
                batch = images[start : start + BATCH_SIZE]
                print(
                    f"Processing batch {batch_number}/{total_batches} ({len(batch)} images)"
                )
                # Use the same max_num_patches for all images in batch:
                # the largest bucket any image in the batch asks for.
                max_patches = max(determine_max_value(img) for img in batch)
                image_input = image_processor(
                    images=batch,
                    max_num_patches=max_patches,
                    return_tensors="pt",
                ).to(DEVICE)
                image_features = model.get_image_features(**image_input)
                # Normalize in float32: doing the division in a reduced
                # precision model dtype leaves vectors only roughly unit-length.
                image_features = image_features.to(dtype=torch.float32)
                normalized_features = image_features / image_features.norm(
                    dim=-1, keepdim=True
                )
                embeddings.append(normalized_features.cpu().numpy())
        embeddings = np.vstack(embeddings)
        print(f"Embeddings shape: {embeddings.shape}")

        # Create JSON output
        result = json.dumps(
            {
                "embeddings": embeddings.tolist(),
                "shape": list(embeddings.shape),
                "count": len(images),
            },
            indent=2,
        )
        message = f"✅ Successfully generated embeddings for {len(images)} images\nShape: {embeddings.shape}"
        print(message)
        return result, message
    except Exception as e:
        # UI boundary: report the failure in the status box instead of raising.
        error_msg = f"❌ Error: {str(e)}"
        print(error_msg)
        return None, error_msg
def extract_frames(video_path: str, fps: int = 4) -> list:
    """Sample frames from a video at roughly ``fps`` frames per second.

    Every ``int(round(video_fps) / fps)``-th decoded frame is kept, starting
    with the first one. When the requested rate is higher than the video's
    own frame rate (or the container reports no frame rate), every frame is
    kept.

    Args:
        video_path: Path to the video file.
        fps: Frames per second to sample; must be positive.

    Returns:
        List of PIL Images in RGB order. Empty if no frame could be read.

    Raises:
        ValueError: If ``fps`` is not positive.
    """
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps}")

    cap = cv2.VideoCapture(video_path)
    frames = []
    try:
        video_fps = cap.get(cv2.CAP_PROP_FPS)
        # Clamp to 1: without it a request above the source frame rate yields
        # an interval of 0 and the modulo below raises ZeroDivisionError.
        frame_interval = max(1, int(round(video_fps) / fps))
        frame_count = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if frame_count % frame_interval == 0:
                # Convert BGR to RGB
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(Image.fromarray(frame_rgb))
            frame_count += 1
    finally:
        # Release the capture handle even if decoding or conversion fails.
        cap.release()
    return frames
@spaces.GPU
def generate_video_embeddings(video_path, fps):
    """Embed frames sampled from a video.

    Args:
        video_path: Path to video file (str).
        fps: Frames per second to extract.

    Returns:
        Tuple ``(result, message)``. On success ``result`` is a JSON string
        with the keys ``embeddings`` (one L2-normalized vector per sampled
        frame, in playback order), ``shape``, ``count`` and ``fps``. On
        failure ``result`` is ``None``. ``message`` is a human-readable
        status text in both cases.
    """
    if not video_path:
        return None, "❌ Please upload a video file"
    try:
        # Extract frames
        print(f"Extracting frames from video: {video_path} at {fps} fps")
        frames = extract_frames(video_path, fps)
        print(f"Extracted {len(frames)} frames from video")
        if not frames:
            return None, "❌ No frames could be extracted from the video"

        # Generate embeddings with batching
        embeddings = []
        total_batches = (len(frames) + BATCH_SIZE - 1) // BATCH_SIZE
        print(f"Generating embeddings for {len(frames)} frames...")
        with torch.no_grad():
            for batch_number, start in enumerate(
                range(0, len(frames), BATCH_SIZE), start=1
            ):
                batch = frames[start : start + BATCH_SIZE]
                print(
                    f"Processing batch {batch_number}/{total_batches} ({len(batch)} frames)"
                )
                # Use the same max_num_patches for all frames in batch:
                # the largest bucket any frame in the batch asks for.
                max_patches = max(determine_max_value(frame) for frame in batch)
                image_input = image_processor(
                    images=batch,
                    max_num_patches=max_patches,
                    return_tensors="pt",
                ).to(DEVICE)
                image_features = model.get_image_features(**image_input)
                # Normalize in float32: doing the division in a reduced
                # precision model dtype leaves vectors only roughly unit-length.
                image_features = image_features.to(dtype=torch.float32)
                normalized_features = image_features / image_features.norm(
                    dim=-1, keepdim=True
                )
                embeddings.append(normalized_features.cpu().numpy())
        embeddings = np.vstack(embeddings)
        print(f"Embeddings shape: {embeddings.shape}")

        # Create JSON output
        result = json.dumps(
            {
                "embeddings": embeddings.tolist(),
                "shape": list(embeddings.shape),
                "count": len(frames),
                "fps": fps,
            },
            indent=2,
        )
        message = f"✅ Successfully generated embeddings for {len(frames)} frames (extracted at {fps} fps)\nShape: {embeddings.shape}"
        print(message)
        return result, message
    except Exception as e:
        # UI boundary: report the failure in the status box instead of raising.
        error_msg = f"❌ Error: {str(e)}"
        print(error_msg)
        return None, error_msg
# Create Gradio interface: one tab per input type, each wiring a submit button
# to its embedding function and showing the JSON result next to a status box.
with gr.Blocks(title="Video & Image Embedding Generator") as demo:
    gr.Markdown("# 🎬 Video & Image Embedding Generator")
    gr.Markdown(f"Generate embeddings using **{MODEL_NAME}** model")

    with gr.Tab("📦 Images from ZIP"):
        gr.Markdown("Upload a ZIP file containing images to generate embeddings")
        with gr.Row():
            with gr.Column():
                zip_input = gr.File(label="Upload ZIP file", file_types=[".zip"])
                img_submit_btn = gr.Button("Generate Embeddings", variant="primary")
            with gr.Column():
                img_output = gr.JSON(label="Embeddings (JSON)")
                img_status = gr.Textbox(label="Status", lines=3)
        img_submit_btn.click(
            fn=generate_image_embeddings,
            inputs=[zip_input],
            outputs=[img_output, img_status],
        )

    with gr.Tab("🎥 Video Frames"):
        gr.Markdown(
            "Upload a video and specify FPS to extract frames and generate embeddings"
        )
        with gr.Row():
            with gr.Column():
                video_input = gr.Video(label="Upload Video")
                fps_input = gr.Slider(
                    minimum=1,
                    maximum=30,
                    value=4,
                    step=1,
                    label="Frames per Second (FPS)",
                )
                vid_submit_btn = gr.Button("Generate Embeddings", variant="primary")
            with gr.Column():
                vid_output = gr.JSON(label="Embeddings (JSON)")
                vid_status = gr.Textbox(label="Status", lines=3)

        def handle_video_upload(video_file, fps):
            """Validate the upload and forward its path to the embedding function."""
            if video_file is None:
                return None, "❌ Please upload a video file"
            # The component may hand over a path or an object exposing ``.name``.
            return generate_video_embeddings(
                video_file.name if hasattr(video_file, "name") else video_file, fps
            )

        vid_submit_btn.click(
            fn=handle_video_upload,
            inputs=[video_input, fps_input],
            outputs=[vid_output, vid_status],
        )

    gr.Markdown(
        """
    ### 📝 Notes:
    - Images in ZIP: Supports PNG, JPG, JPEG, BMP, WEBP formats
    - Video: Supports common video formats (MP4, AVI, MOV, etc.)
    - Output: JSON object containing normalized embeddings with metadata
    - Structure: `{"embeddings": [...], "shape": [n, dim], "count": n, "fps": f}` (`fps` is included for video only)
    """
    )
# Launch the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()