"""Precompute CLIP embeddings for the images listed in images.csv."""

from __future__ import annotations

import argparse
import json
import logging
from io import BytesIO
from pathlib import Path
from typing import Iterable

import numpy as np
import requests
import torch
from PIL import Image

from model import ClipScorer, ImageEntry, load_image_entries

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("precompute_embeddings")


def download_image(url: str) -> Image.Image:
    # Use a browser-like User-Agent so hosts that block generic clients still serve the image.
    headers = {
        "User-Agent": (
            "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 "
            "(KHTML, like Gecko) Chrome/115.0.0.0 Safari/537.36"
        )
    }
    response = requests.get(url, headers=headers, timeout=30)
    response.raise_for_status()
    return Image.open(BytesIO(response.content))


def save_embedding(path: Path, embedding: torch.Tensor) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    array = embedding.detach().cpu().numpy().astype(np.float32)
    # Store as JSON if the target path ends in ".json", otherwise as a NumPy .npy file.
    suffix = path.suffix.lower()
    if suffix == ".json":
        data = {
            "embedding": array.tolist(),
        }
        with path.open("w", encoding="utf-8") as handle:
            json.dump(data, handle, ensure_ascii=False, separators=(",", ":"))
    else:
        np.save(path, array)


def compute_embeddings(entries: Iterable[ImageEntry], model_name: str = "jinaai/jina-clip-v2") -> None:
    scorer = ClipScorer(model_name=model_name)
    processed = 0
    for entry in entries:
        # Skip entries recorded for a different CLIP model.
        if entry.clip_model != model_name:
            logger.info(
                "Skipping %s because clip_model %s does not match model %s.",
                entry.image_id,
                entry.clip_model,
                model_name,
            )
            continue
        # Skip images whose embedding has already been computed.
        if entry.embedding_path.exists():
            logger.info("Embedding for %s already exists (%s)", entry.image_id, entry.embedding_path)
            continue
        logger.info("Downloading image %s", entry.image_id)
        image = download_image(entry.image_url)
        features = scorer.encode_image(image)
        save_embedding(entry.embedding_path, features)
        processed += 1
        logger.info("Saved embedding for %s (%s)", entry.image_id, entry.embedding_path)
    logger.info("Done. %d embeddings created.", processed)


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Computes CLIP embeddings for the images listed in images.csv")
    parser.add_argument("--csv", default="images.csv", help="Path to images.csv")
    parser.add_argument(
        "--model-name",
        default="jinaai/jina-clip-v2",
        help="Hugging Face repository of the desired CLIP model",
    )
    return parser.parse_args()


def main() -> None:
    args = parse_args()
    entries = load_image_entries(Path(args.csv))
    compute_embeddings(entries, model_name=args.model_name)


if __name__ == "__main__":
    main()