# game/model.py
from __future__ import annotations
import csv
import json
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class ImageEntry:
"""Container für Bildmetadaten und Pfade zu Embeddings."""
image_id: str
image_url: str
clip_model: str
embedding_path: Path
def load_image_entries(csv_path: Path | str) -> List[ImageEntry]:
"""Liest die Bildliste aus einer CSV-Datei."""
path = Path(csv_path)
if not path.exists():
raise FileNotFoundError(f"Die Datei {path} existiert nicht.")
entries: List[ImageEntry] = []
with path.open("r", encoding="utf-8") as csvfile:
reader = csv.DictReader(csvfile)
for row in reader:
image_id = row.get("image_id") or row.get("id")
image_url = row.get("image_url") or row.get("url")
clip_model = row.get("clip_model") or "jinaai/jina-clip-v2"
embedding_path = row.get("embedding_path") or f"embeddings/{image_id}.npy"
entries.append(
ImageEntry(
image_id=image_id,
image_url=image_url,
clip_model=clip_model,
embedding_path=Path(embedding_path),
)
)
return entries
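
# A minimal sketch of the CSV layout load_image_entries() expects; the file
# name and row below are illustrative, not files shipped with this repository:
#
#   image_id,image_url,clip_model,embedding_path
#   cat01,https://example.com/cat01.jpg,jinaai/jina-clip-v2,embeddings/cat01.npy
#
# Only image_id/id and image_url/url are required; clip_model defaults to
# "jinaai/jina-clip-v2" and embedding_path to "embeddings/<image_id>.npy".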
def similarity_to_score(similarity: float) -> int:
"""Wandelt eine Kosinusähnlichkeit (-1 bis 1) in einen Score von 0 bis 1000 um."""
clipped = max(-1.0, min(1.0, similarity))
score = int(round(((clipped + 1.0) / 2.0) * 1000))
return score
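
# For orientation: similarity_to_score(-1.0) -> 0, similarity_to_score(0.0) -> 500,
# similarity_to_score(0.5) -> 750 and similarity_to_score(1.0) -> 1000; values
# outside [-1, 1] are clipped first.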
def _require_numpy():
try:
import numpy as np # type: ignore
except ModuleNotFoundError as exc: # pragma: no cover - defensive fallback
raise ModuleNotFoundError("numpy wird benötigt, ist aber nicht installiert.") from exc
return np
def _require_torch():
try:
import torch # type: ignore
except ModuleNotFoundError as exc: # pragma: no cover - defensive fallback
raise ModuleNotFoundError("torch wird benötigt, ist aber nicht installiert.") from exc
return torch
def _require_transformers():
try:
from transformers import AutoModel, AutoProcessor # type: ignore
except ModuleNotFoundError as exc: # pragma: no cover - defensive fallback
raise ModuleNotFoundError("transformers wird benötigt, ist aber nicht installiert.") from exc
return AutoModel, AutoProcessor
class ClipScorer:
"""Wrapper um CLIP für Text-/Bild-Embeddings und Scores."""
def __init__(
self,
model_name: str = "jinaai/jina-clip-v2",
pretrained: Optional[str] = None,
device: Optional[str] = None,
) -> None:
self.model_name = model_name
self.pretrained = pretrained
torch = _require_torch()
AutoModel, AutoProcessor = _require_transformers()
self._torch = torch
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
if pretrained and pretrained != model_name:
logger.warning(
"Der Parameter 'pretrained' (%s) wird für transformers-basierte Modelle ignoriert.",
pretrained,
)
logger.info("Lade CLIP Modell %s auf %s", model_name, self.device)
self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
self.model.to(self.device)
self.model.eval()
for parameter in self.model.parameters():
parameter.requires_grad = False
config = getattr(self.model, "config", None)
embedding_dim = None
if config is not None:
embedding_dim = getattr(config, "projection_dim", None)
if embedding_dim is None:
embedding_dim = getattr(config, "hidden_size", None)
self.embedding_dim: Optional[int] = embedding_dim
self._image_embeddings: Dict[str, Any] = {}
def load_precomputed_embeddings(self, entries: Iterable[ImageEntry]) -> None:
"""Lädt Embeddings aus .npy-Dateien und speichert sie intern."""
loaded = 0
for entry in entries:
if entry.clip_model != self.model_name:
logger.warning(
"Überspringe Bild %s: erwartet Modell %s, gefunden %s",
entry.image_id,
self.model_name,
entry.clip_model,
)
continue
if not entry.embedding_path.exists():
raise FileNotFoundError(
f"Embedding-Datei für {entry.image_id} fehlt: {entry.embedding_path}"
)
torch = self._torch
suffix = entry.embedding_path.suffix.lower()
if suffix == ".json":
with entry.embedding_path.open("r", encoding="utf-8") as handle:
payload = json.load(handle)
if isinstance(payload, dict):
values = (
payload.get("embedding")
or payload.get("values")
or payload.get("data")
)
else:
values = payload
if values is None:
raise ValueError(
f"Embedding-Datei {entry.embedding_path} enthält keine Werte."
)
tensor = torch.tensor(values, dtype=torch.float32, device=self.device)
if tensor.ndim > 1:
tensor = tensor.view(-1)
else:
np = _require_numpy()
array = np.load(entry.embedding_path)
if array.ndim > 1:
array = array.squeeze()
tensor = torch.from_numpy(array).to(self.device)
tensor = tensor.to(dtype=torch.float32)
expected_dim = self.embedding_dim
if expected_dim is not None and tensor.shape[-1] != expected_dim:
raise ValueError(
"Embedding-Dimension stimmt nicht mit dem geladenen Modell überein. "
f"Erwartet: {expected_dim}, erhalten: {tensor.shape[-1]} für {entry.image_id}."
)
norm = torch.linalg.norm(tensor)
if norm == 0:
raise ValueError(f"Embedding für {entry.image_id} hat Norm 0.")
tensor = tensor / norm
self._image_embeddings[entry.image_id] = tensor
loaded += 1
if loaded == 0:
raise ValueError("Keine Embeddings konnten geladen werden.")
logger.info("%d Embeddings geladen.", loaded)
    def encode_text(self, text: str) -> Any:
        """Returns the normalized embedding for a single text string."""
torch = self._torch
inputs = self.processor(text=[text], return_tensors="pt", padding=True, truncation=True)
inputs = {key: value.to(self.device) for key, value in inputs.items() if isinstance(value, torch.Tensor)}
with torch.no_grad():
text_features = self.model.get_text_features(**inputs).float()
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
return text_features[0]
    def encode_image(self, image: Any) -> Any:
        """Returns the normalized embedding for a single image."""
torch = self._torch
inputs = self.processor(images=image, return_tensors="pt")
inputs = {key: value.to(self.device) for key, value in inputs.items() if isinstance(value, torch.Tensor)}
with torch.no_grad():
image_features = self.model.get_image_features(**inputs).float()
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
return image_features[0]
    def get_image_embedding(self, image_id: str) -> Any:
        """Returns the cached embedding for image_id, raising KeyError if it was never loaded."""
try:
return self._image_embeddings[image_id]
except KeyError as exc:
raise KeyError(f"Kein Embedding für Bild-ID {image_id} geladen.") from exc
    def compute_similarity(self, text_embedding: Any, image_embedding: Any) -> float:
        """Returns the cosine similarity, i.e. the dot product of the two normalized embeddings."""
torch = self._torch
similarity = torch.matmul(text_embedding, image_embedding)
return float(similarity.item())
    def score_text_for_image(self, text: str, image_id: str) -> tuple[float, int]:
        """Encodes the text, compares it with the cached image embedding and returns (similarity, score)."""
text_embedding = self.encode_text(text)
image_embedding = self.get_image_embedding(image_id)
similarity = self.compute_similarity(text_embedding, image_embedding)
score = similarity_to_score(similarity)
return similarity, score
__all__ = [
"ClipScorer",
"ImageEntry",
"load_image_entries",
"similarity_to_score",
]
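

if __name__ == "__main__":
    # Minimal usage sketch; the CSV path, prompt and logging setup below are
    # placeholders rather than assets shipped with this repository.
    logging.basicConfig(level=logging.INFO)
    entries = load_image_entries("data/images.csv")
    scorer = ClipScorer()
    scorer.load_precomputed_embeddings(entries)
    similarity, score = scorer.score_text_for_image(
        "a cat sleeping on a sofa", entries[0].image_id
    )
    print(f"similarity={similarity:.4f} score={score}")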