| from __future__ import annotations | |
| import argparse | |
| import json | |
| import logging | |
| from pathlib import Path | |
| from typing import Iterable | |
| import numpy as np | |
| import requests | |
| import torch | |
| from PIL import Image | |
| from io import BytesIO | |
| from model import ClipScorer, ImageEntry, load_image_entries | |
_LOGGER_NAME = "precompute_embeddings"

# Configure root logging at import time so the INFO-level progress messages
# emitted below are visible when the script runs.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(_LOGGER_NAME)
def download_image(url: str, *, timeout: float = 30.0) -> Image.Image:
    """Download the image at *url* and return it as a decoded PIL image.

    Args:
        url: HTTP(S) address of the image.
        timeout: Seconds ``requests`` waits for the connection and for each
            read before giving up. Without a timeout ``requests`` can block
            indefinitely on an unresponsive server.

    Returns:
        The downloaded image, already fully decoded.

    Raises:
        requests.HTTPError: If the server responds with an error status.
        requests.RequestException: On connection errors or timeouts.
        OSError: If the payload cannot be identified or decoded as an image
            (PIL's ``UnidentifiedImageError`` is a subclass of ``OSError``).
    """
    # A browser-like User-Agent is sent; presumably some image hosts reject
    # the default python-requests agent (NOTE(review): confirm still needed).
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/115.0.0.0 Safari/537.36"
    }
    response = requests.get(url, headers=headers, timeout=timeout)
    response.raise_for_status()
    image = Image.open(BytesIO(response.content))
    # Image.open only reads the header; decode the pixel data now so a
    # corrupt payload fails here rather than later inside the model.
    image.load()
    return image
def save_embedding(path: Path, embedding: torch.Tensor) -> None:
    """Persist *embedding* as float32 data at exactly *path*.

    A ``.json`` suffix (case-insensitive) produces ``{"embedding": [...]}``.
    Any other suffix produces NumPy ``.npy`` binary data. Missing parent
    directories are created.

    The data is first written to a sibling ``<name>.tmp`` file and then
    moved onto *path*. An interrupted run therefore never leaves a truncated
    file at *path*, and a later existence check on *path* can be trusted.

    Args:
        path: Destination file; the file is written at this exact name.
        embedding: Tensor to store. It is detached, moved to the CPU and
            cast to float32 (this also covers bfloat16, which ``.numpy()``
            cannot convert directly).
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    # Cast in torch before .numpy(): numpy has no bfloat16 dtype.
    array = embedding.detach().to(torch.float32).cpu().numpy()
    tmp_path = path.with_name(path.name + ".tmp")
    try:
        if path.suffix.lower() == ".json":
            data = {
                "embedding": array.tolist(),
            }
            with tmp_path.open("w", encoding="utf-8") as handle:
                json.dump(data, handle, ensure_ascii=False, separators=(",", ":"))
        else:
            # np.save appends ".npy" to a *filename* that lacks it, which
            # would put the data somewhere other than *path*. Writing through
            # an open file object stores it under the exact name.
            with tmp_path.open("wb") as handle:
                np.save(handle, array)
        tmp_path.replace(path)
    except BaseException:
        # Never leave a stray temp file behind; propagate the original error.
        tmp_path.unlink(missing_ok=True)
        raise
def compute_embeddings(entries: Iterable[ImageEntry], model_name: str = "jinaai/jina-clip-v2") -> None:
    """Compute and store CLIP image embeddings for *entries*.

    An entry is skipped when its ``clip_model`` differs from *model_name*,
    or when its ``embedding_path`` already exists. For every other entry the
    image at ``image_url`` is downloaded, encoded with ``ClipScorer`` and
    written to ``embedding_path``.

    The model is loaded lazily, on the first entry that actually needs an
    embedding, so a run where everything is already done does not pay the
    model-loading cost.

    An entry whose image cannot be downloaded or decoded is logged and
    skipped instead of aborting the whole batch. Because finished
    embeddings are skipped, a rerun retries only the failed entries. The
    number of failures is logged at the end.

    Args:
        entries: Image entries to process.
        model_name: Hugging Face repository of the CLIP model to use.
    """
    scorer: ClipScorer | None = None
    processed = 0
    failed = 0
    for entry in entries:
        if entry.clip_model != model_name:
            logger.info(
                "Überspringe %s, da clip_model %s nicht zum Modell %s passt.",
                entry.image_id,
                entry.clip_model,
                model_name,
            )
            continue
        if entry.embedding_path.exists():
            logger.info("Embedding für %s existiert bereits (%s)", entry.image_id, entry.embedding_path)
            continue
        if scorer is None:
            # Deferred until an entry really needs encoding (see docstring).
            scorer = ClipScorer(model_name=model_name)
        logger.info("Lade Bild %s", entry.image_id)
        try:
            image = download_image(entry.image_url)
        except (requests.RequestException, OSError):
            failed += 1
            logger.exception("Bild %s konnte nicht geladen werden (%s)", entry.image_id, entry.image_url)
            continue
        # Close the decoded image once encoded so its pixel buffer is freed.
        with image:
            features = scorer.encode_image(image)
        save_embedding(entry.embedding_path, features)
        processed += 1
        logger.info("Embedding für %s gespeichert (%s)", entry.image_id, entry.embedding_path)
    logger.info("Fertig. %d Embeddings erzeugt.", processed)
    if failed:
        logger.warning("%d Bilder konnten nicht geladen werden.", failed)
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of this script.

    Args:
        argv: Arguments to parse, without the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, as before. Pass an
            explicit list for programmatic use or tests.

    Returns:
        Namespace with ``csv`` (path to the CSV file, default
        ``images.csv``) and ``model_name`` (Hugging Face repository, default
        ``jinaai/jina-clip-v2``).
    """
    parser = argparse.ArgumentParser(description="Berechnet CLIP-Embeddings für die Bilder aus images.csv")
    parser.add_argument("--csv", default="images.csv", help="Pfad zur images.csv")
    parser.add_argument(
        "--model-name",
        default="jinaai/jina-clip-v2",
        help="Hugging-Face-Repository des gewünschten CLIP-Modells",
    )
    return parser.parse_args(argv)
def main() -> None:
    """Script entry point: read CLI options, load the CSV entries, embed them."""
    options = parse_args()
    csv_path = Path(options.csv)
    image_entries = load_image_entries(csv_path)
    compute_embeddings(image_entries, model_name=options.model_name)
# Run the command-line entry point only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()