# game / db.py
# chrimons
# Build Gradio CLIP scoring app
# 3c8c1fc
from __future__ import annotations

import os
import re
from contextlib import contextmanager
from datetime import datetime, timezone
from typing import Iterator, List, Optional, Sequence

from sqlalchemy import (
    Column,
    DateTime,
    Float,
    ForeignKey,
    Index,
    Integer,
    String,
    Text,
    create_engine,
    func,
    select,
)
from sqlalchemy.engine import Engine
from sqlalchemy.orm import Session, declarative_base, relationship, sessionmaker
# Declarative base class that the ORM models in this module inherit from.
Base = declarative_base()
# Accepted usernames: 3 to 20 characters drawn from letters, digits, "_", "." and "-".
USERNAME_REGEX = re.compile(r"^[A-Za-z0-9_.-]{3,20}$")
# Module-wide engine and session factory; both remain None until
# configure_database() assigns them.
_engine: Optional[Engine] = None
_SessionLocal: Optional[sessionmaker] = None
class User(Base):
    """ORM model mapped to the ``users`` table."""

    __tablename__ = "users"
    id = Column(Integer, primary_key=True)
    # Unique, indexed lookup key.
    username = Column(String(20), unique=True, nullable=False, index=True)  # canonical (lowercase)
    display_name = Column(String(20), nullable=False)
    # Filled in by the database server (func.now()) when the row is inserted.
    created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
    # Counterpart of Score.user; configured with the "all, delete-orphan" cascade.
    scores = relationship("Score", back_populates="user", cascade="all, delete-orphan")
class Score(Base):
    """ORM model mapped to the ``scores`` table."""

    __tablename__ = "scores"
    id = Column(Integer, primary_key=True)
    # Foreign key to users.id, declared with ON DELETE CASCADE.
    user_id = Column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
    username = Column(String(20), nullable=False)  # original spelling
    canonical_username = Column(String(20), nullable=False, index=True)
    image_id = Column(String(100), nullable=False, index=True)
    score = Column(Integer, nullable=False)
    similarity = Column(Float, nullable=False)
    text = Column(Text, nullable=False)
    # Filled in by the database server (func.now()) when the row is inserted.
    created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False, index=True)
    # Counterpart of User.scores.
    user = relationship("User", back_populates="scores")
    __table_args__ = (
        # Same column order and directions as the ORDER BY in get_global_top().
        Index("ix_scores_global", score.desc(), created_at.asc(), id.asc()),
        # image_id filter followed by the leading sort keys of get_image_top().
        Index("ix_scores_image", "image_id", score.desc(), created_at.asc()),
        # canonical_username filter followed by the ORDER BY of get_user_recent().
        Index("ix_scores_user", "canonical_username", created_at.desc(), id.desc()),
    )
def configure_database(database_url: Optional[str] = None) -> Engine:
    """Create the module-wide engine and session factory.

    Args:
        database_url: SQLAlchemy connection URL. When ``None``, the value of
            the ``DATABASE_URL`` environment variable is used instead.

    Returns:
        The newly created engine, which is also stored in the module globals.

    Raises:
        RuntimeError: If neither the argument nor the environment provides a
            non-empty URL.
    """
    global _engine, _SessionLocal

    url = database_url if database_url is not None else os.getenv("DATABASE_URL")
    if not url:
        raise RuntimeError("DATABASE_URL ist nicht gesetzt.")

    # Build both replacements before touching the globals so a failure leaves
    # the previously configured engine/factory pair intact and consistent.
    new_engine = create_engine(url, future=True, pool_pre_ping=True)
    new_factory = sessionmaker(
        bind=new_engine,
        autoflush=False,
        expire_on_commit=False,
        future=True,
    )

    previous_engine = _engine
    _engine = new_engine
    _SessionLocal = new_factory

    if previous_engine is not None:
        # Reconfiguring would otherwise leave the old engine's pooled
        # connections open until garbage collection.
        previous_engine.dispose()

    return new_engine
def get_engine() -> Engine:
    """Return the module-wide engine.

    Raises:
        RuntimeError: If the engine has not been set yet.
    """
    if _engine is not None:
        return _engine
    raise RuntimeError("Datenbank wurde noch nicht konfiguriert. Rufen Sie configure_database() auf.")
def init_db() -> None:
    """Run ``create_all`` on ``Base.metadata`` against the configured engine."""
    Base.metadata.create_all(get_engine())
@contextmanager
def session_scope() -> Iterator[Session]:
    """Provide a session wrapped in a commit/rollback/close cycle.

    Yields:
        A new session from the module-wide session factory.

    Raises:
        RuntimeError: On entering the context, if the session factory has not
            been set yet.

    The session is committed when the managed block finishes normally. If the
    block (or the commit) raises an ``Exception``, the session is rolled back
    and the exception is re-raised. The session is closed in every case.
    """
    if _SessionLocal is None:
        raise RuntimeError("SessionFactory nicht initialisiert. configure_database() zuerst aufrufen.")
    session: Session = _SessionLocal()
    try:
        yield session
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()
def normalize_username(username: str) -> str:
    """Return *username* with surrounding whitespace removed and lowercased."""
    trimmed = username.strip()
    return trimmed.lower()
def validate_username(username: str) -> bool:
    """Return whether *username* consists solely of 3 to 20 allowed characters.

    Allowed characters are those accepted by ``USERNAME_REGEX``: ASCII letters,
    digits, ``_``, ``.`` and ``-``. The input is checked as given; no
    stripping or case folding is applied here.
    """
    # fullmatch() rather than match(): the pattern's "$" also matches just
    # before a trailing newline, so match() would accept e.g. "abc\n".
    return USERNAME_REGEX.fullmatch(username) is not None
def ensure_user(session: Session, username_input: str) -> User:
    """Return the user for *username_input*, inserting a new row if none exists.

    The lookup is done on the normalized name. A newly created user stores the
    normalized name as ``username`` and the stripped input as
    ``display_name``; it is added to the session and flushed, not committed.
    """
    canonical = normalize_username(username_input)
    lookup = select(User).where(User.username == canonical)
    existing = session.execute(lookup).scalar_one_or_none()
    if existing:
        return existing

    created = User(username=canonical, display_name=username_input.strip())
    session.add(created)
    session.flush()
    return created
def create_score(
    session: Session,
    *,
    user: User,
    image_id: str,
    score_value: int,
    similarity: float,
    text: str,
) -> Score:
    """Add a ``Score`` row for *user* to the session and flush it.

    The row copies ``user.id``, ``user.display_name`` (as ``username``) and
    ``user.username`` (as ``canonical_username``) alongside the given
    image id, score, similarity and text. Nothing is committed here.

    Returns:
        The flushed ``Score`` instance.
    """
    column_values = {
        "user_id": user.id,
        "username": user.display_name,
        "canonical_username": user.username,
        "image_id": image_id,
        "score": score_value,
        "similarity": similarity,
        "text": text,
    }
    new_score = Score(**column_values)
    session.add(new_score)
    session.flush()
    return new_score
def _format_timestamp(value: datetime | str | None) -> str:
    """Render *value* as ``YYYY-MM-DD HH:MM:SSZ`` in UTC.

    ``None`` yields an empty string and strings are returned unchanged.
    A datetime without tzinfo is treated as already being in UTC; an aware
    datetime is converted to UTC before formatting.
    """
    if value is None:
        return ""
    if isinstance(value, str):
        return value

    aware = value if value.tzinfo is not None else value.replace(tzinfo=timezone.utc)
    in_utc = aware.astimezone(timezone.utc)
    return in_utc.strftime("%Y-%m-%d %H:%M:%SZ")
def scores_to_rows(scores: Sequence[Score], include_rank: bool = True) -> List[List[object]]:
    """Convert score records into plain list rows.

    Each row holds username, image id, score, similarity rounded to four
    decimals, text and the formatted creation timestamp, in that order.
    When *include_rank* is true, a 1-based position in *scores* is prepended.
    """

    def _cells(entry: Score) -> List[object]:
        # Column order shared by ranked and unranked output.
        return [
            entry.username,
            entry.image_id,
            entry.score,
            round(entry.similarity, 4),
            entry.text,
            _format_timestamp(entry.created_at),
        ]

    if include_rank:
        return [[rank, *_cells(entry)] for rank, entry in enumerate(scores, start=1)]
    return [_cells(entry) for entry in scores]
def get_global_top(session: Session, limit: int = 50) -> List[Score]:
    """Return up to *limit* scores across all images, best first.

    Rows are ordered by score descending, then by creation time ascending,
    then by id ascending.
    """
    ordering = (Score.score.desc(), Score.created_at.asc(), Score.id.asc())
    query = select(Score).order_by(*ordering).limit(limit)
    found = session.scalars(query)
    return list(found)
def get_image_top(session: Session, image_id: str, limit: int = 50) -> List[Score]:
    """Return up to *limit* scores recorded for *image_id*, best first.

    Rows are ordered by score descending, then by creation time ascending,
    then by id ascending.
    """
    ordering = (Score.score.desc(), Score.created_at.asc(), Score.id.asc())
    query = select(Score).where(Score.image_id == image_id)
    query = query.order_by(*ordering).limit(limit)
    found = session.scalars(query)
    return list(found)
def get_user_recent(session: Session, canonical_username: str, limit: int = 50) -> List[Score]:
    """Return up to *limit* scores for *canonical_username*, newest first.

    Rows are ordered by creation time descending, then by id descending.
    """
    ordering = (Score.created_at.desc(), Score.id.desc())
    query = select(Score).where(Score.canonical_username == canonical_username)
    query = query.order_by(*ordering).limit(limit)
    found = session.scalars(query)
    return list(found)
# Names exported by ``from <module> import *``; private helpers and the
# module-level engine/session-factory globals are not listed.
__all__ = [
    "Base",
    "Score",
    "User",
    "configure_database",
    "create_score",
    "ensure_user",
    "get_engine",
    "get_global_top",
    "get_image_top",
    "get_user_recent",
    "init_db",
    "normalize_username",
    "scores_to_rows",
    "session_scope",
    "validate_username",
]