# game/precompute_embeddings.py
# NOTE(review): the lines below are residue from a Hugging Face file-view page
# (commit 948475f, "refactor to jinai"); kept as a comment so the module parses.
from __future__ import annotations
import argparse
import json
import logging
from pathlib import Path
from typing import Iterable
import numpy as np
import requests
import torch
from PIL import Image
from io import BytesIO
from model import ClipScorer, ImageEntry, load_image_entries
# Root logger configuration for standalone script use: INFO level so the
# per-image progress messages emitted below are visible on the console.
logging.basicConfig(level=logging.INFO)
# Named logger for this script (name chosen explicitly rather than __name__,
# so messages carry "precompute_embeddings" even when run as __main__).
logger = logging.getLogger("precompute_embeddings")
def download_image(url: str, timeout: float = 30.0) -> Image.Image:
    """Download the image at *url* and return it as a PIL image.

    A browser-like User-Agent header is sent because some image hosts reject
    requests that carry the default ``python-requests`` agent string.

    Args:
        url: HTTP(S) URL of the image to fetch.
        timeout: Seconds to wait for the server before raising
            ``requests.Timeout``. New backward-compatible parameter; the
            original call had no timeout and could hang the batch forever
            on an unresponsive host.

    Returns:
        The image decoded from the response body.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the server does not respond within *timeout*.
    """
    headers = {
        "User-Agent": (
            "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 "
            "(KHTML, like Gecko) Chrome/115.0.0.0 Safari/537.36"
        )
    }
    response = requests.get(url, headers=headers, timeout=timeout)
    response.raise_for_status()
    return Image.open(BytesIO(response.content))
def save_embedding(path: Path, embedding: torch.Tensor) -> None:
    """Persist *embedding* to *path*, creating parent directories as needed.

    The tensor is detached, moved to CPU and cast to float32 first. A
    ``.json`` suffix selects a compact JSON document of the form
    ``{"embedding": [...]}``; any other suffix gets NumPy's binary format.

    Args:
        path: Destination file; its suffix selects the format.
        embedding: Embedding vector to store (any device, any float dtype).
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    array = embedding.detach().cpu().numpy().astype(np.float32)
    if path.suffix.lower() == ".json":
        with path.open("w", encoding="utf-8") as handle:
            json.dump(
                {"embedding": array.tolist()},
                handle,
                ensure_ascii=False,
                separators=(",", ":"),
            )
    else:
        # Write through an open handle: np.save(path, ...) silently appends
        # ".npy" when the suffix differs, so the file would NOT end up at
        # `path` and the embedding_path.exists() skip-check upstream would
        # never fire. Saving to a file object keeps the exact path.
        with path.open("wb") as handle:
            np.save(handle, array)
def compute_embeddings(entries: Iterable[ImageEntry], model_name: str = "jinaai/jina-clip-v2") -> None:
    """Compute and persist CLIP embeddings for every matching entry.

    For each entry whose ``clip_model`` matches *model_name* and whose
    embedding file does not yet exist, the image is downloaded, encoded with
    the scorer and written to ``entry.embedding_path``. Failures on a single
    entry (network error, undecodable image) are logged and skipped so one
    bad URL no longer aborts the whole batch — previously any exception
    propagated and lost all remaining work.

    Args:
        entries: Image entries to process (presumably from images.csv via
            ``load_image_entries`` — see ``main``).
        model_name: Hugging Face repository of the CLIP model; entries tagged
            with a different model are skipped.
    """
    scorer = ClipScorer(model_name=model_name)
    processed = 0
    for entry in entries:
        if entry.clip_model != model_name:
            logger.info(
                "Überspringe %s, da clip_model %s nicht zum Modell %s passt.",
                entry.image_id,
                entry.clip_model,
                model_name,
            )
            continue
        if entry.embedding_path.exists():
            logger.info("Embedding für %s existiert bereits (%s)", entry.image_id, entry.embedding_path)
            continue
        logger.info("Lade Bild %s", entry.image_id)
        try:
            image = download_image(entry.image_url)
            features = scorer.encode_image(image)
            save_embedding(entry.embedding_path, features)
        except Exception:
            # Best effort: record the failure and keep going with the rest.
            logger.exception("Fehler beim Verarbeiten von %s, Eintrag wird übersprungen.", entry.image_id)
            continue
        processed += 1
        logger.info("Embedding für %s gespeichert (%s)", entry.image_id, entry.embedding_path)
    logger.info("Fertig. %d Embeddings erzeugt.", processed)
def parse_args() -> argparse.Namespace:
    """Collect the command-line options for the precomputation run."""
    cli = argparse.ArgumentParser(
        description="Berechnet CLIP-Embeddings für die Bilder aus images.csv"
    )
    cli.add_argument("--csv", default="images.csv", help="Pfad zur images.csv")
    cli.add_argument(
        "--model-name",
        default="jinaai/jina-clip-v2",
        help="Hugging-Face-Repository des gewünschten CLIP-Modells",
    )
    return cli.parse_args()
def main() -> None:
    """Script entry point: load entries from the CSV and precompute embeddings."""
    options = parse_args()
    csv_path = Path(options.csv)
    compute_embeddings(load_image_entries(csv_path), model_name=options.model_name)


if __name__ == "__main__":
    main()