File size: 9,585 Bytes
7c94b61
4b20c0d
7c94b61
 
 
 
 
 
 
 
 
 
 
 
92a946e
7c94b61
 
 
4b20c0d
7c94b61
 
4b20c0d
 
 
 
 
 
7c94b61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9ecb18c
7c94b61
 
 
 
 
 
 
 
 
9ecb18c
 
7c94b61
 
 
6510c49
7c94b61
9ecb18c
7c94b61
6510c49
 
 
 
 
 
 
 
 
7c94b61
6510c49
 
7c94b61
 
6510c49
7c94b61
6510c49
 
7c94b61
 
e7b5ffc
 
 
7c94b61
 
9ecb18c
7c94b61
92a946e
 
 
 
 
 
 
 
 
7c94b61
 
9ecb18c
92a946e
7c94b61
 
9ecb18c
 
 
7c94b61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6510c49
7c94b61
 
 
 
6510c49
7c94b61
 
 
 
 
 
 
6510c49
 
9ecb18c
7c94b61
 
 
 
6510c49
7c94b61
9ecb18c
7c94b61
6510c49
 
 
 
 
 
 
 
 
7c94b61
6510c49
 
7c94b61
 
6510c49
7c94b61
6510c49
 
7c94b61
 
e7b5ffc
 
 
7c94b61
 
9ecb18c
7c94b61
92a946e
 
 
 
 
 
 
 
 
 
7c94b61
 
9ecb18c
92a946e
7c94b61
 
9ecb18c
 
 
7c94b61
 
 
 
 
 
 
 
 
 
 
 
 
 
92a946e
7c94b61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92a946e
7c94b61
 
6510c49
 
 
 
 
 
 
7c94b61
6510c49
7c94b61
 
 
 
 
 
 
 
 
92a946e
 
7c94b61
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
import spaces
from install_flsh_attn import attn_implementation, dtype
from transformers import (
    AutoImageProcessor,
    AutoModelForCausalLM,
)
import gradio as gr
import torch
from accelerate import Accelerator
import numpy as np
import cv2
from PIL import Image
import zipfile
import io
import json

# Device selected by Hugging Face Accelerate for this process; model inputs and
# the model itself are moved here.
DEVICE = Accelerator().device
# Hugging Face Hub id used for both the model and its image processor.
MODEL_NAME = "qihoo360/fg-clip2-so400m"
# Number of images/frames passed to the model per forward pass by the
# embedding functions in this file.
BATCH_SIZE = 128


# Load the model and processor once at import time so every request reuses them.
# `dtype` and `attn_implementation` come from the local `install_flsh_attn`
# module. trust_remote_code=True lets the model repository's own modeling code
# run (required for this custom architecture — NOTE(review): confirm).
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    dtype=dtype,
    attn_implementation=attn_implementation,
).to(DEVICE)
image_processor = AutoImageProcessor.from_pretrained(MODEL_NAME)


def determine_max_value(image):
    """Determine max_num_patches based on image size.

    The image is measured in whole 16x16-pixel patches. The smallest bucket
    among 128, 256, 576 and 784 that can hold that patch count is returned;
    anything larger maps to 1024.

    Args:
        image: Object exposing ``size`` as ``(width, height)`` (e.g. a PIL image).

    Returns:
        int: One of 128, 256, 576, 784 or 1024.
    """
    width, height = image.size
    patch_count = (width // 16) * (height // 16)
    for bucket in (128, 256, 576, 784):
        if patch_count <= bucket:
            return bucket
    return 1024


_IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".bmp", ".webp")


def _is_image_entry(file_info):
    """Return True if a zip entry should be decoded as an image.

    Directories and macOS archive metadata (anything under ``__MACOSX/`` and
    AppleDouble ``._*`` files) are rejected even when their names end in an
    image extension, because they do not contain image data.
    """
    if file_info.is_dir():
        return False
    name = file_info.filename
    # Zip member names always use "/" as the separator.
    basename = name.rsplit("/", 1)[-1]
    if name.startswith("__MACOSX/") or basename.startswith("._"):
        return False
    return name.lower().endswith(_IMAGE_EXTENSIONS)


def _load_zip_images(zip_path):
    """Decode the image entries of a zip archive into RGB PIL images.

    Returns:
        Tuple ``(images, filenames, skipped)``: the decoded images and their
        archive names (both in archive order), plus the names of image-looking
        entries that PIL could not decode.
    """
    images, filenames, skipped = [], [], []
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        for file_info in zip_ref.infolist():
            if not _is_image_entry(file_info):
                continue
            try:
                with Image.open(io.BytesIO(zip_ref.read(file_info))) as img:
                    rgb = img.convert("RGB")
            except (OSError, Image.DecompressionBombError):
                # One corrupt/unsupported file should not abort the whole upload.
                skipped.append(file_info.filename)
                continue
            images.append(rgb)
            filenames.append(file_info.filename)
    return images, filenames, skipped


@spaces.GPU
def generate_image_embeddings(zip_file):
    """
    Generate L2-normalized embeddings for every image inside a zip archive.

    Images are encoded in batches of ``BATCH_SIZE``; each batch uses the
    largest ``max_num_patches`` bucket required by any image in it.

    Args:
        zip_file: Uploaded zip file, either a path string or an object whose
            ``.name`` attribute holds the path. ``None`` means nothing was
            uploaded.

    Returns:
        Tuple of (JSON string, status message). The JSON object has keys
        "embeddings" (one vector per image, in archive order), "shape",
        "count" and "filenames" (archive name of each embedded image).
        On failure the JSON string is None and the message starts with "❌".
    """
    if zip_file is None:
        return None, "❌ Please upload a ZIP file"
    # Accept both a plain path and a file-like wrapper exposing the path as `.name`.
    zip_path = getattr(zip_file, "name", zip_file)

    try:
        print(f"Extracting images from zip file: {zip_path}")
        images, filenames, skipped = _load_zip_images(zip_path)
        print(f"Extracted {len(images)} images from zip file")
        if skipped:
            print(f"Skipped {len(skipped)} unreadable file(s): {skipped}")

        if not images:
            return None, "❌ No valid images found in the zip file"

        # Generate embeddings with batching
        embeddings = []
        num_batches = (len(images) + BATCH_SIZE - 1) // BATCH_SIZE
        print(f"Generating embeddings for {len(images)} images...")
        with torch.no_grad():
            for i in range(0, len(images), BATCH_SIZE):
                batch = images[i : i + BATCH_SIZE]
                print(
                    f"Processing batch {i // BATCH_SIZE + 1}/{num_batches} ({len(batch)} images)"
                )

                # Use the same max_num_patches for all images in batch
                max_patches = max(determine_max_value(img) for img in batch)

                image_input = image_processor(
                    images=batch,
                    max_num_patches=max_patches,
                    return_tensors="pt",
                ).to(DEVICE)
                image_features = model.get_image_features(**image_input)

                # L2-normalize each embedding so dot products are cosine similarities
                normalized_features = image_features / image_features.norm(
                    dim=-1, keepdim=True
                )
                embeddings.append(
                    normalized_features.to(dtype=torch.float32).cpu().numpy()
                )

        embeddings = np.vstack(embeddings)
        print(f"Embeddings shape: {embeddings.shape}")

        # Create JSON output
        result = json.dumps(
            {
                "embeddings": embeddings.tolist(),
                "shape": list(embeddings.shape),
                "count": len(images),
                "filenames": filenames,
            },
            indent=2,
        )

        message = f"✅ Successfully generated embeddings for {len(images)} images\nShape: {embeddings.shape}"
        if skipped:
            message += (
                f"\nSkipped {len(skipped)} unreadable file(s): {', '.join(skipped)}"
            )
        print(message)
        return result, message

    except Exception as e:
        # Top-level UI boundary: report the failure in the status box instead of crashing.
        error_msg = f"❌ Error: {str(e)}"
        print(error_msg)
        return None, error_msg


def extract_frames(video_path: str, fps: int = 4):
    """
    Extract frames from a video, sampling roughly ``fps`` frames per second.

    Every ``frame_interval``-th frame is kept. The interval is the rounded
    native frame rate divided by ``fps``, truncated, and never less than 1.
    If the requested rate exceeds the native rate, or the container reports
    no frame rate (OpenCV returns 0), every frame is kept.

    Args:
        video_path: Path to the video file
        fps: Frames per second to sample; must be positive

    Returns:
        List of RGB PIL Images in playback order (empty if the video cannot
        be opened or decoded).

    Raises:
        ValueError: If ``fps`` is not a positive number.
    """
    # `not fps > 0` also rejects NaN.
    if not fps > 0:
        raise ValueError(f"fps must be positive, got {fps!r}")

    cap = cv2.VideoCapture(video_path)
    try:
        video_fps = cap.get(cv2.CAP_PROP_FPS)
        if video_fps > 0:
            # Truncation keeps sampling at least as dense as requested; the clamp
            # prevents a 0 interval (and `% 0`) when fps exceeds the native rate.
            frame_interval = max(1, int(round(video_fps) / fps))
        else:
            # Native rate unknown: keep every frame rather than crash.
            frame_interval = 1

        frames = []
        frame_count = 0
        # grab() only advances to the next frame; retrieve() returns the decoded
        # frame, so it is called only for frames that are kept.
        while cap.grab():
            if frame_count % frame_interval == 0:
                ret, frame = cap.retrieve()
                if not ret:
                    break
                # OpenCV yields BGR; PIL/the image processor expect RGB
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(Image.fromarray(frame_rgb))
            frame_count += 1
    finally:
        # Release the capture even if decoding/conversion raises.
        cap.release()
    return frames


@spaces.GPU
def generate_video_embeddings(video_path, fps):
    """
    Generate L2-normalized embeddings for frames sampled from a video.

    Frames are sampled with ``extract_frames`` and encoded in batches of
    ``BATCH_SIZE``; each batch uses the largest ``max_num_patches`` bucket
    required by any frame in it.

    Args:
        video_path: Path to video file (str)
        fps: Frames per second to extract

    Returns:
        Tuple of (JSON string, status message). The JSON object has keys
        "embeddings" (one vector per sampled frame), "shape", "count" and
        "fps". On failure the JSON string is None and the message starts
        with "❌".
    """
    try:
        # Extract frames
        print(f"Extracting frames from video: {video_path} at {fps} fps")
        frames = extract_frames(video_path, fps)
        print(f"Extracted {len(frames)} frames from video")

        if not frames:
            return None, "❌ No frames could be extracted from the video"

        # Generate embeddings with batching
        embeddings = []
        num_batches = (len(frames) + BATCH_SIZE - 1) // BATCH_SIZE
        print(f"Generating embeddings for {len(frames)} frames...")
        with torch.no_grad():
            for i in range(0, len(frames), BATCH_SIZE):
                batch = frames[i : i + BATCH_SIZE]
                print(
                    f"Processing batch {i // BATCH_SIZE + 1}/{num_batches} ({len(batch)} frames)"
                )

                # Use the same max_num_patches for all frames in batch
                max_patches = max(determine_max_value(frame) for frame in batch)

                image_input = image_processor(
                    images=batch,
                    max_num_patches=max_patches,
                    return_tensors="pt",
                ).to(DEVICE)
                image_features = model.get_image_features(**image_input)

                # L2-normalize each embedding so dot products are cosine similarities
                normalized_features = image_features / image_features.norm(
                    dim=-1, keepdim=True
                )
                embeddings.append(
                    normalized_features.to(dtype=torch.float32).cpu().numpy()
                )

        embeddings = np.vstack(embeddings)
        print(f"Embeddings shape: {embeddings.shape}")

        # Create JSON output
        result = json.dumps(
            {
                "embeddings": embeddings.tolist(),
                "shape": list(embeddings.shape),
                "count": len(frames),
                "fps": fps,
            },
            indent=2,
        )

        message = f"✅ Successfully generated embeddings for {len(frames)} frames (extracted at {fps} fps)\nShape: {embeddings.shape}"
        print(message)
        return result, message

    except Exception as e:
        # Top-level UI boundary: report the failure in the status box instead of crashing.
        error_msg = f"❌ Error: {str(e)}"
        print(error_msg)
        return None, error_msg


# Create Gradio interface: one tab for ZIP-of-images input, one for video input.
with gr.Blocks(title="Video & Image Embedding Generator") as demo:
    gr.Markdown("# 🎬 Video & Image Embedding Generator")
    gr.Markdown(f"Generate embeddings using **{MODEL_NAME}** model")

    with gr.Tab("📦 Images from ZIP"):
        gr.Markdown("Upload a ZIP file containing images to generate embeddings")
        with gr.Row():
            with gr.Column():
                zip_input = gr.File(label="Upload ZIP file", file_types=[".zip"])
                img_submit_btn = gr.Button("Generate Embeddings", variant="primary")
            with gr.Column():
                img_output = gr.JSON(label="Embeddings (JSON)")
                img_status = gr.Textbox(label="Status", lines=3)

        img_submit_btn.click(
            fn=generate_image_embeddings,
            inputs=[zip_input],
            outputs=[img_output, img_status],
        )

    with gr.Tab("🎥 Video Frames"):
        gr.Markdown(
            "Upload a video and specify FPS to extract frames and generate embeddings"
        )
        with gr.Row():
            with gr.Column():
                video_input = gr.Video(label="Upload Video")
                fps_input = gr.Slider(
                    minimum=1,
                    maximum=30,
                    value=4,
                    step=1,
                    label="Frames per Second (FPS)",
                )
                vid_submit_btn = gr.Button("Generate Embeddings", variant="primary")
            with gr.Column():
                vid_output = gr.JSON(label="Embeddings (JSON)")
                vid_status = gr.Textbox(label="Status", lines=3)

        def handle_video_upload(video_file, fps):
            """Validate the video input and forward its path to the embedder.

            Accepts either a path string or an object exposing the path as
            ``.name``; returns ``(None, error message)`` when nothing was uploaded.
            """
            if video_file is None:
                return None, "❌ Please upload a video file"
            video_path = video_file.name if hasattr(video_file, "name") else video_file
            return generate_video_embeddings(video_path, fps)

        vid_submit_btn.click(
            fn=handle_video_upload,
            inputs=[video_input, fps_input],
            outputs=[vid_output, vid_status],
        )

    gr.Markdown(
        """
    ### 📝 Notes:
    - Images in ZIP: Supports PNG, JPG, JPEG, BMP, WEBP formats
    - Video: Supports common video formats (MP4, AVI, MOV, etc.)
    - Output: JSON object containing normalized embeddings with metadata
    - Structure: `{"embeddings": [...], "shape": [n, dim], "count": n, "fps": f}` (`fps` is included for video output only)
    """
    )


if __name__ == "__main__":
    demo.launch()