# NOTE: removed page-scrape artifacts that preceded the module source
# (file-size banner, git commit hashes, and a captured line-number gutter).
from __future__ import annotations

import argparse
import json
import logging
from pathlib import Path
from typing import Iterable

import numpy as np
import requests
import torch
from PIL import Image
from io import BytesIO

from model import ClipScorer, ImageEntry, load_image_entries

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("precompute_embeddings")


def download_image(url: str, timeout: float = 30.0) -> Image.Image:
    """Download the image at *url* and return it as a PIL image.

    Args:
        url: HTTP(S) URL of the image to fetch.
        timeout: Seconds to wait for the server; without a timeout a
            single stalled host would hang the whole batch run forever.

    Returns:
        The decoded PIL image.

    Raises:
        requests.HTTPError: If the server responds with a 4xx/5xx status.
        requests.RequestException: On connection errors or timeout.
    """
    # Browser-like User-Agent: some image hosts reject the default
    # python-requests UA with 403.
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/115.0.0.0 Safari/537.36"
    }
    response = requests.get(url, headers=headers, timeout=timeout)
    response.raise_for_status()
    return Image.open(BytesIO(response.content))


def save_embedding(path: Path, embedding: torch.Tensor) -> None:
    """Persist an embedding tensor to *path* as JSON or binary ``.npy`` data.

    Args:
        path: Target file. A ``.json`` suffix (case-insensitive) selects
            compact JSON output; any other suffix gets NumPy's ``.npy``
            binary format. Parent directories are created as needed.
        embedding: The embedding vector; it is detached, moved to CPU and
            stored as float32.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    array = embedding.detach().cpu().numpy().astype(np.float32)
    if path.suffix.lower() == ".json":
        data = {
            "embedding": array.tolist(),
        }
        with path.open("w", encoding="utf-8") as handle:
            json.dump(data, handle, ensure_ascii=False, separators=(",", ":"))
    else:
        # Write through an open file handle: np.save(path, ...) silently
        # appends ".npy" when the suffix differs, so the file would land at
        # a different path than the one callers later check with exists().
        with path.open("wb") as handle:
            np.save(handle, array)


def compute_embeddings(entries: Iterable[ImageEntry], model_name: str = "jinaai/jina-clip-v2") -> None:
    """Compute and persist CLIP embeddings for every matching entry.

    Entries whose ``clip_model`` differs from *model_name*, or whose
    embedding file already exists, are skipped. A failure on one image is
    logged and does not abort the rest of the batch.

    Args:
        entries: Image entries to process.
        model_name: Hugging Face repository of the CLIP model to load.
    """
    scorer = ClipScorer(model_name=model_name)
    processed = 0
    failed = 0
    for entry in entries:
        if entry.clip_model != model_name:
            logger.info(
                "Überspringe %s, da clip_model %s nicht zum Modell %s passt.",
                entry.image_id,
                entry.clip_model,
                model_name,
            )
            continue
        if entry.embedding_path.exists():
            logger.info("Embedding für %s existiert bereits (%s)", entry.image_id, entry.embedding_path)
            continue
        logger.info("Lade Bild %s", entry.image_id)
        try:
            # OSError covers requests' RequestException (an IOError
            # subclass since requests 2.x) as well as PIL's decode errors,
            # so a single bad URL no longer kills the whole run.
            image = download_image(entry.image_url)
            features = scorer.encode_image(image)
        except OSError:
            failed += 1
            logger.exception("Fehler beim Verarbeiten von %s, überspringe.", entry.image_id)
            continue
        save_embedding(entry.embedding_path, features)
        processed += 1
        logger.info("Embedding für %s gespeichert (%s)", entry.image_id, entry.embedding_path)
    logger.info("Fertig. %d Embeddings erzeugt, %d fehlgeschlagen.", processed, failed)


def parse_args() -> argparse.Namespace:
    """Read and return this script's command-line options.

    Returns:
        Namespace with ``csv`` (path to the image CSV) and ``model_name``
        (Hugging Face repository of the CLIP model).
    """
    description = "Berechnet CLIP-Embeddings für die Bilder aus images.csv"
    cli = argparse.ArgumentParser(description=description)
    cli.add_argument("--csv", default="images.csv", help="Pfad zur images.csv")
    cli.add_argument(
        "--model-name",
        default="jinaai/jina-clip-v2",
        help="Hugging-Face-Repository des gewünschten CLIP-Modells",
    )
    return cli.parse_args()


def main() -> None:
    """Script entry point: load entries from the CSV and embed them."""
    options = parse_args()
    csv_entries = load_image_entries(Path(options.csv))
    compute_embeddings(csv_entries, model_name=options.model_name)


if __name__ == "__main__":
    main()