esfiles / server.py
Besjon Cifliku
feat: update logging in hf spaces
e29b232
"""
FastAPI server for the Contextual Similarity Engine.
Endpoints:
/api/train/* β€” Train/adapt models (3 strategies)
/api/init β€” Load a model into the engine
/api/documents β€” Add documents to the corpus
/api/index/build β€” Build FAISS index
/api/query β€” Semantic search
/api/compare β€” Compare two texts
/api/analyze/* β€” Keyword analysis
/api/match β€” Keyword meaning matching
/api/eval/* β€” Evaluation metrics
/api/w2v/* β€” Word2Vec baseline comparison
/api/dataset/* β€” HuggingFace dataset loading (Epstein Files)
"""
import asyncio
import logging
import os
import time
import threading
from collections import deque
from pathlib import Path
from typing import Literal, Optional
from fastapi import FastAPI, HTTPException, Query, UploadFile, File, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field
from contextual_similarity import ContextualSimilarityEngine
from evaluation import Evaluator, GroundTruthEntry
from training import CorpusTrainer
from word2vec_baseline import Word2VecEngine
from data_loader import load_raw_dataset, load_raw_to_engine, import_chromadb_to_engine, get_dataset_info
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ------------------------------------------------------------------ #
# Log streaming buffer
# ------------------------------------------------------------------ #
class LogBuffer(logging.Handler):
    """Thread-safe log handler that buffers recent messages for SSE streaming.

    Cursors returned by get_new_lines() are monotonic counts of lines ever
    emitted (not indexes into the bounded deque), so delivery keeps working
    after the buffer wraps around.
    """
    def __init__(self, max_lines: int = 500):
        super().__init__()
        # Bounded history: oldest lines are discarded past max_lines.
        self._buffer: deque[str] = deque(maxlen=max_lines)
        self._total = 0  # total lines emitted since startup (cursor base)
        self._lock = threading.Lock()
        self._event = threading.Event()
    def emit(self, record: logging.LogRecord) -> None:
        """Format and buffer a record; never raise from the logging path."""
        try:
            msg = self.format(record)
        except Exception:
            # logging.Handler contract: report failures via handleError
            # instead of propagating into application code.
            self.handleError(record)
            return
        with self._lock:
            self._buffer.append(msg)
            self._total += 1
            self._event.set()
    def get_new_lines(self, after: int) -> tuple[list[str], int]:
        """Return lines emitted after cursor `after`, and the new cursor.

        Fix: the cursor previously indexed the deque directly, so once
        max_lines lines had ever been emitted (len stopped growing) no new
        lines were returned again. The cursor is now the total emitted
        count; lines that scrolled out of the buffer unread are skipped.
        """
        with self._lock:
            total = self._total
            buffered = list(self._buffer)
        if after >= total:
            return [], total
        dropped = total - len(buffered)  # lines no longer retained
        start = max(after - dropped, 0)
        return buffered[start:], total
# Singleton buffer backing the /api/logs/stream and /api/logs/poll endpoints.
log_buffer = LogBuffer()
log_buffer.setFormatter(logging.Formatter("%(asctime)s %(name)s %(message)s", datefmt="%H:%M:%S"))
log_buffer.setLevel(logging.INFO)
# Attach to root logger so all modules' logs are captured
logging.getLogger().addHandler(log_buffer)
# ------------------------------------------------------------------ #
# Security constants & validation helpers
# ------------------------------------------------------------------ #
# Models loadable by name; anything else is treated as a local path and must
# resolve inside BASE_DIR (see _validate_model_name).
ALLOWED_MODELS = frozenset({
    "all-MiniLM-L6-v2",
    "all-mpnet-base-v2",
    "BAAI/bge-large-en-v1.5",
})
# Prefixes accepted for the dataset endpoints' source_filter parameter.
ALLOWED_SOURCE_FILTERS = frozenset({"TEXT-", "IMAGES-"})
MAX_UPLOAD_BYTES = 10 * 1024 * 1024  # 10 MB
# Project root; sandbox boundary for every user-supplied filesystem path.
BASE_DIR = Path(__file__).parent.resolve()
def _validate_model_name(name: str) -> str:
    """Allow known HuggingFace models or local paths within project dir."""
    if name not in ALLOWED_MODELS:
        # Not on the allow-list: treat it as a local model path, which must
        # stay inside the project directory.
        _validate_safe_path(name)
    return name
def _validate_safe_path(path_str: str) -> str:
    """Reject paths that escape the project directory."""
    candidate = Path(path_str).resolve()
    if candidate.is_relative_to(BASE_DIR):
        return path_str
    raise HTTPException(400, "Path must be within the project directory.")
def _to_native(obj):
"""Recursively convert numpy types to native Python types for JSON serialization."""
import numpy as np
if isinstance(obj, dict):
return {k: _to_native(v) for k, v in obj.items()}
if isinstance(obj, (list, tuple)):
return [_to_native(v) for v in obj]
if isinstance(obj, (np.integer,)):
return int(obj)
if isinstance(obj, (np.floating,)):
return float(obj)
if isinstance(obj, np.ndarray):
return obj.tolist()
return obj
# ------------------------------------------------------------------ #
# App setup
# ------------------------------------------------------------------ #
app = FastAPI(
    title="Contextual Similarity API",
    version="1.0.0",
)
app.add_middleware(
    CORSMiddleware,
    # NOTE(review): CORSMiddleware does not glob-match allow_origins entries,
    # so "https://*.hf.space" is inert here — the allow_origin_regex below is
    # what actually admits *.hf.space origins. Confirm and consider removing.
    allow_origins=["http://localhost:5173", "http://localhost:3000",
                   "https://huggingface.co", "https://*.hf.space"],
    allow_origin_regex=r"https://.*\.hf\.space",
    allow_credentials=True,
    allow_methods=["GET", "POST"],
    allow_headers=["Content-Type", "Authorization"],
)
# Global instances: one engine/evaluator/Word2Vec model per process.
engine: Optional[ContextualSimilarityEngine] = None
evaluator: Optional[Evaluator] = None
w2v_engine: Optional[Word2VecEngine] = None
# Persistence directories, overridable via env for container deployments.
ENGINE_SAVE_DIR = Path(os.environ.get("ENGINE_STATE_DIR", str(BASE_DIR / "engine_state")))
W2V_SAVE_DIR = Path(os.environ.get("W2V_STATE_DIR", str(BASE_DIR / "w2v_state")))
@app.on_event("startup")
def _auto_restore():
"""Restore engine and W2V state from disk if previous saves exist."""
global engine, evaluator, w2v_engine
if (ENGINE_SAVE_DIR / "meta.json").is_file():
try:
engine = ContextualSimilarityEngine.load(str(ENGINE_SAVE_DIR))
if engine.index is not None:
evaluator = Evaluator(engine)
logger.info("Auto-restored engine: %d chunks, %d docs",
len(engine.chunks), len(engine._doc_ids))
except Exception:
logger.exception("Failed to auto-restore engine state β€” starting fresh")
if Word2VecEngine.has_saved_state(str(W2V_SAVE_DIR)):
try:
w2v_engine = Word2VecEngine.load(str(W2V_SAVE_DIR))
logger.info("Auto-restored Word2Vec: %d sentences, %d vocab",
len(w2v_engine.sentences), len(w2v_engine.model.wv))
except Exception:
logger.exception("Failed to auto-restore Word2Vec state β€” starting fresh")
@app.get("/api/logs/stream")
async def stream_logs():
"""SSE endpoint: streams server log lines in real-time."""
async def event_generator():
cursor = 0
# Send initial snapshot
lines, cursor = log_buffer.get_new_lines(cursor)
for line in lines[-20:]: # last 20 lines on connect
yield f"data: {line}\n\n"
cursor = max(cursor, 0)
while True:
await asyncio.sleep(0.5)
lines, cursor = log_buffer.get_new_lines(cursor)
if lines:
for line in lines:
yield f"data: {line}\n\n"
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
)
@app.get("/api/logs/poll")
def poll_logs(cursor: int = Query(default=0, ge=0)):
"""Polling fallback for log streaming (works through HF Spaces proxy)."""
lines, new_cursor = log_buffer.get_new_lines(cursor)
return {"lines": lines, "cursor": new_cursor}
# ------------------------------------------------------------------ #
# Request models (with input validation)
# ------------------------------------------------------------------ #
# All request bodies enforce size/range bounds via pydantic Field constraints,
# so handlers can assume bounded input. (Comments only — no docstrings, to
# avoid changing the generated OpenAPI descriptions.)
class TrainRequest(BaseModel):
    # Shared base for the three training strategies.
    corpus_texts: list[str] = Field(max_length=10_000)
    base_model: str = "all-MiniLM-L6-v2"
    output_path: str = "./trained_model"  # validated against BASE_DIR before use
    epochs: int = Field(default=5, ge=1, le=50)
    batch_size: int = Field(default=16, ge=1, le=256)
class TrainKeywordsRequest(TrainRequest):
    # keyword -> intended meaning; training pairs are generated from this map.
    keyword_meanings: dict[str, str]
class TrainEvalRequest(BaseModel):
    # Each pair dict carries text_a/text_b plus "expected" or "score".
    test_pairs: list[dict] = Field(max_length=1_000)
    trained_model_path: str = "./trained_model"
    base_model: str = "all-MiniLM-L6-v2"
    corpus_texts: list[str] = Field(default=[], max_length=10_000)
class InitRequest(BaseModel):
    model_name: str = "all-MiniLM-L6-v2"
    chunk_size: int = Field(default=512, ge=50, le=8192)
    # Must be strictly less than chunk_size (checked in the handler).
    chunk_overlap: int = Field(default=128, ge=0, le=4096)
    batch_size: int = Field(default=64, ge=1, le=512)
class DocumentRequest(BaseModel):
    doc_id: str = Field(max_length=200)
    text: str = Field(max_length=10_000_000)  # ~10 MB of text
class QueryRequest(BaseModel):
    text: str = Field(max_length=10_000)
    top_k: int = Field(default=10, ge=1, le=100)
class CompareRequest(BaseModel):
    text_a: str = Field(max_length=50_000)
    text_b: str = Field(max_length=50_000)
class KeywordAnalysisRequest(BaseModel):
    keyword: str = Field(max_length=200)
    top_k: int = Field(default=10, ge=1, le=100)
    cluster_threshold: float = Field(default=0.35, ge=0.01, le=1.0)
class BatchAnalysisRequest(BaseModel):
    keywords: list[str] = Field(max_length=50)
    top_k: int = Field(default=10, ge=1, le=100)
    cluster_threshold: float = Field(default=0.35, ge=0.01, le=1.0)
    compare_across: bool = True  # also compute cross-keyword similarities
class KeywordMatchRequest(BaseModel):
    keyword: str = Field(max_length=200)
    candidate_meanings: list[str] = Field(max_length=50)
class EvalDisambiguationRequest(BaseModel):
    # Entries need keyword/text/true_meaning keys (see the handler).
    ground_truth: list[dict] = Field(max_length=10_000)
    candidate_meanings: dict[str, list[str]]
class EvalRetrievalRequest(BaseModel):
    queries: list[dict] = Field(max_length=1_000)
    k_values: list[int] = Field(default=[1, 3, 5, 10], max_length=10)
class W2VInitRequest(BaseModel):
    corpus_texts: list[str] = Field(max_length=10_000)
    vector_size: int = Field(default=100, ge=50, le=500)
    window: int = Field(default=5, ge=1, le=20)
    epochs: int = Field(default=50, ge=1, le=200)
class W2VCompareRequest(BaseModel):
    text_a: str = Field(max_length=50_000)
    text_b: str = Field(max_length=50_000)
class W2VQueryRequest(BaseModel):
    text: str = Field(max_length=10_000)
    top_k: int = Field(default=10, ge=1, le=100)
class W2VWordRequest(BaseModel):
    word: str = Field(max_length=200)
    top_k: int = Field(default=10, ge=1, le=100)
class ContextAnalysisRequest(BaseModel):
    keyword: str = Field(max_length=200)
    cluster_threshold: float = Field(default=0.35, ge=0.01, le=1.0)
    top_words: int = Field(default=8, ge=1, le=30)
class DatasetLoadRequest(BaseModel):
    # "raw" loads document text; "embeddings" imports precomputed ChromaDB data.
    source: Literal["raw", "embeddings"] = "raw"
    max_docs: int = Field(default=500, ge=1, le=100_000)
    min_text_length: int = Field(default=100, ge=0, le=100_000)
    source_filter: Optional[str] = None  # must be in ALLOWED_SOURCE_FILTERS
    build_index: bool = True
# ------------------------------------------------------------------ #
# Training endpoints
# ------------------------------------------------------------------ #
def _run_training(req: TrainRequest, strategy: str, train_fn):
    """Common wrapper for all training endpoints: validate, log, time, train."""
    # Both the model name and the output path are sandbox-checked up front.
    _validate_model_name(req.base_model)
    _validate_safe_path(req.output_path)
    logger.info("Training (%s): model=%s, corpus=%d texts, epochs=%d, batch=%d",
                strategy, req.base_model, len(req.corpus_texts), req.epochs, req.batch_size)
    started = time.time()
    trainer = CorpusTrainer(req.corpus_texts, req.base_model)
    outcome = train_fn(trainer)
    logger.info("Training (%s) complete in %.1fs β†’ %s", strategy, time.time() - started, req.output_path)
    return outcome
@app.post("/api/train/unsupervised")
def train_unsupervised(req: TrainRequest):
"""Soft-label domain adaptation. No labels needed."""
return _run_training(req, "unsupervised",
lambda t: t.train_unsupervised(req.output_path, req.epochs, req.batch_size))
@app.post("/api/train/contrastive")
def train_contrastive(req: TrainRequest):
"""Contrastive: learns from corpus structure (adjacent sentences = similar)."""
return _run_training(req, "contrastive",
lambda t: t.train_contrastive(req.output_path, req.epochs, req.batch_size))
@app.post("/api/train/keywords")
def train_keywords(req: TrainKeywordsRequest):
"""Keyword-supervised: provide keyword→meaning map, pairs auto-generated."""
return _run_training(req, "keyword-supervised",
lambda t: t.train_with_keywords(req.keyword_meanings, req.output_path, req.epochs, req.batch_size))
@app.post("/api/train/evaluate")
def train_evaluate(req: TrainEvalRequest):
"""Compare base model vs trained model on test pairs."""
_validate_model_name(req.base_model)
_validate_model_name(req.trained_model_path)
logger.info("Evaluating: base=%s vs trained=%s, %d test pairs",
req.base_model, req.trained_model_path, len(req.test_pairs))
corpus = req.corpus_texts or ["placeholder text for initialization."]
trainer = CorpusTrainer(corpus, req.base_model)
test_pairs = [
(p["text_a"], p["text_b"], p.get("expected", p.get("score", 0.5)))
for p in req.test_pairs
]
result = trainer.evaluate_model(test_pairs, req.trained_model_path)
logger.info("Evaluation complete: %d pairs evaluated", len(test_pairs))
return result
# ------------------------------------------------------------------ #
# Engine endpoints
# ------------------------------------------------------------------ #
@app.post("/api/init")
def init_engine(req: InitRequest):
"""Initialize the similarity engine with a model (pretrained or trained)."""
_validate_model_name(req.model_name)
if req.chunk_overlap >= req.chunk_size:
raise HTTPException(400, "chunk_overlap must be less than chunk_size.")
global engine, evaluator
logger.info("Initializing engine: model=%s, chunk_size=%d, overlap=%d, batch=%d",
req.model_name, req.chunk_size, req.chunk_overlap, req.batch_size)
t0 = time.time()
engine = ContextualSimilarityEngine(
model_name=req.model_name,
chunk_size=req.chunk_size,
chunk_overlap=req.chunk_overlap,
batch_size=req.batch_size,
)
evaluator = None
elapsed = round(time.time() - t0, 2)
logger.info("Engine initialized in %.2fs (model=%s)", elapsed, req.model_name)
return {"status": "ok", "model": req.model_name, "load_time_seconds": elapsed}
@app.post("/api/documents")
def add_document(req: DocumentRequest):
_ensure_engine()
logger.info("Adding document: id=%s, text_length=%d", req.doc_id, len(req.text))
try:
chunks = engine.add_document(req.doc_id, req.text)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
logger.info("Document '%s' added: %d chunks", req.doc_id, len(chunks))
return {
"status": "ok", "doc_id": req.doc_id, "num_chunks": len(chunks),
"chunks_preview": [{"index": c.chunk_index, "text": c.text[:150]} for c in chunks[:5]],
}
@app.post("/api/documents/upload")
async def upload_document(file: UploadFile = File(...), doc_id: Optional[str] = Form(None)):
_ensure_engine()
contents = await file.read()
if len(contents) > MAX_UPLOAD_BYTES:
raise HTTPException(413, f"File too large. Maximum size is {MAX_UPLOAD_BYTES // (1024 * 1024)}MB.")
try:
text = contents.decode("utf-8")
except UnicodeDecodeError:
raise HTTPException(400, "File must be valid UTF-8 text.")
d_id = doc_id or Path(file.filename or "upload").stem
try:
chunks = engine.add_document(d_id, text)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return {"status": "ok", "doc_id": d_id, "num_chunks": len(chunks)}
@app.post("/api/index/build")
def build_index():
_ensure_engine()
logger.info("Building FAISS index (%d documents in corpus)...", len(engine.corpus))
t0 = time.time()
try:
engine.build_index(show_progress=True)
except RuntimeError as e:
raise HTTPException(status_code=400, detail=str(e))
global evaluator
evaluator = Evaluator(engine)
elapsed = round(time.time() - t0, 2)
logger.info("FAISS index built: %d vectors (dim=%d) in %.2fs",
engine.index.ntotal, engine.embedding_dim, elapsed)
# Auto-save so data persists across restarts
try:
engine.save(str(ENGINE_SAVE_DIR))
except Exception:
logger.warning("Auto-save after index build failed", exc_info=True)
return {
"status": "ok", "total_chunks": engine.index.ntotal,
"embedding_dim": engine.embedding_dim, "build_time_seconds": elapsed,
}
@app.post("/api/query")
def query_similar(req: QueryRequest):
_ensure_engine(); _ensure_index()
logger.info("Query: text='%s...' top_k=%d", req.text[:80], req.top_k)
results = engine.query(req.text, top_k=req.top_k)
return {"query": req.text, "results": [
{"rank": r.rank, "score": round(r.score, 4), "doc_id": r.chunk.doc_id,
"chunk_index": r.chunk.chunk_index, "text": r.chunk.text}
for r in results
]}
@app.post("/api/compare")
def compare_texts(req: CompareRequest):
_ensure_engine()
logger.info("Compare: text_a='%s...' vs text_b='%s...'", req.text_a[:60], req.text_b[:60])
return {"text_a": req.text_a, "text_b": req.text_b, "similarity": round(engine.compare_texts(req.text_a, req.text_b), 4)}
@app.post("/api/analyze/keyword")
def analyze_keyword(req: KeywordAnalysisRequest):
_ensure_engine(); _ensure_index()
logger.info("Keyword analysis: keyword='%s', top_k=%d, threshold=%.2f",
req.keyword, req.top_k, req.cluster_threshold)
return _serialize_analysis(engine.analyze_keyword(req.keyword, req.top_k, req.cluster_threshold))
@app.post("/api/analyze/similar-words")
def analyze_similar_words(req: W2VWordRequest):
"""Find words that appear in similar contexts using transformer embeddings."""
_ensure_engine(); _ensure_index()
logger.info("Similar words (transformer): word='%s', top_k=%d", req.word, req.top_k)
results = engine.similar_words(req.word, req.top_k)
return {"word": req.word, "similar": results}
@app.post("/api/analyze/context")
def analyze_context(req: ContextAnalysisRequest):
"""Infer what a keyword likely means from its surrounding context words."""
_ensure_engine(); _ensure_index()
logger.info("Context analysis: keyword='%s', threshold=%.2f, top_words=%d",
req.keyword, req.cluster_threshold, req.top_words)
return engine.infer_keyword_meanings(
req.keyword,
top_words=req.top_words,
cluster_threshold=req.cluster_threshold,
)
@app.post("/api/analyze/batch")
def batch_analyze(req: BatchAnalysisRequest):
_ensure_engine(); _ensure_index()
logger.info("Batch analysis: %d keywords=%s, top_k=%d",
len(req.keywords), req.keywords[:5], req.top_k)
results = engine.batch_analyze_keywords(req.keywords, req.top_k, req.cluster_threshold, req.compare_across)
return {kw: _serialize_analysis(a) for kw, a in results.items()}
@app.post("/api/match")
def match_keyword(req: KeywordMatchRequest):
_ensure_engine(); _ensure_index()
results = engine.match_keyword_to_meaning(req.keyword, req.candidate_meanings)
return {"keyword": req.keyword, "candidate_meanings": req.candidate_meanings, "matches": [
{"doc_id": r["chunk"].doc_id, "chunk_index": r["chunk"].chunk_index, "text": r["chunk"].text[:300],
"best_match": r["best_match"], "best_score": round(r["best_score"], 4),
"all_scores": {k: round(v, 4) for k, v in r["all_scores"].items()}}
for r in results
]}
# ------------------------------------------------------------------ #
# Evaluation endpoints
# ------------------------------------------------------------------ #
@app.post("/api/eval/disambiguation")
def evaluate_disambiguation(req: EvalDisambiguationRequest):
_ensure_evaluator()
gt = [GroundTruthEntry(keyword=e["keyword"], text=e["text"], true_meaning=e["true_meaning"]) for e in req.ground_truth]
metrics = evaluator.evaluate_disambiguation(gt, req.candidate_meanings)
return _to_native({"metrics": [
{"keyword": m.keyword, "accuracy": m.accuracy, "weighted_f1": m.weighted_f1,
"per_meaning_precision": m.per_meaning_precision, "per_meaning_recall": m.per_meaning_recall,
"per_meaning_f1": m.per_meaning_f1, "confusion_matrix": m.confusion, "total_samples": m.total_samples}
for m in metrics
]})
@app.post("/api/eval/retrieval")
def evaluate_retrieval(req: EvalRetrievalRequest):
_ensure_evaluator()
metrics = evaluator.evaluate_retrieval(req.queries, req.k_values)
return _to_native({"metrics": [
{"query": m.query, "mrr": round(float(m.mrr), 4),
"precision_at_k": {str(k): round(float(v), 4) for k, v in m.precision_at_k.items()},
"recall_at_k": {str(k): round(float(v), 4) for k, v in m.recall_at_k.items()},
"ndcg_at_k": {str(k): round(float(v), 4) for k, v in m.ndcg_at_k.items()},
"avg_similarity": round(float(m.avg_similarity), 4), "top_score": round(float(m.top_score), 4)}
for m in metrics
]})
@app.get("/api/eval/similarity-distribution")
def similarity_distribution():
_ensure_evaluator()
return _to_native(evaluator.analyze_similarity_distribution())
@app.get("/api/eval/report")
def get_eval_report():
_ensure_evaluator()
return _to_native(evaluator.get_report().summary())
@app.get("/api/stats")
def get_stats():
_ensure_engine()
return engine.get_stats()
@app.get("/api/corpus/texts")
def get_corpus_texts(max_docs: int = Query(default=500, ge=1, le=10_000)):
"""Return loaded document texts grouped by doc_id (for use as training corpus)."""
_ensure_engine()
# Group chunks by doc_id
docs: dict[str, list[str]] = {}
for chunk in engine.chunks:
if chunk.doc_id not in docs:
docs[chunk.doc_id] = []
docs[chunk.doc_id].append(chunk.text)
# Combine chunks per document
result = []
for doc_id in sorted(docs.keys()):
if len(result) >= max_docs:
break
result.append({"doc_id": doc_id, "text": "\n".join(docs[doc_id])})
return {"documents": result, "count": len(result)}
@app.get("/api/documents/{doc_id}")
def get_document(doc_id: str):
"""Return the full text of a document by reconstructing its chunks."""
_ensure_engine()
chunks = [c for c in engine.chunks if c.doc_id == doc_id]
if not chunks:
raise HTTPException(404, f"Document '{doc_id}' not found.")
chunks.sort(key=lambda c: c.chunk_index)
full_text = "\n".join(c.text for c in chunks)
return {"doc_id": doc_id, "text": full_text, "num_chunks": len(chunks)}
@app.post("/api/engine/save")
def save_engine():
"""Save current engine state to disk for later restore."""
_ensure_engine()
result = engine.save(str(ENGINE_SAVE_DIR))
return {"status": "ok", **result}
@app.post("/api/engine/load")
def load_engine_state():
"""Load a previously saved engine state from disk."""
global engine, evaluator
if not (ENGINE_SAVE_DIR / "meta.json").is_file():
raise HTTPException(400, "No saved engine state found.")
engine = ContextualSimilarityEngine.load(str(ENGINE_SAVE_DIR))
evaluator = Evaluator(engine) if engine.index is not None else None
return {"status": "ok", **engine.get_stats()}
@app.get("/api/engine/has-saved-state")
def has_saved_state():
"""Check if a saved engine state exists on disk."""
exists = (ENGINE_SAVE_DIR / "meta.json").is_file()
return {"exists": exists}
# ------------------------------------------------------------------ #
# Word2Vec baseline endpoints
# ------------------------------------------------------------------ #
@app.post("/api/w2v/init")
def w2v_init(req: W2VInitRequest):
"""Train Word2Vec on corpus for comparison."""
global w2v_engine
logger.info("Word2Vec init: %d texts, vector_size=%d, window=%d, epochs=%d",
len(req.corpus_texts), req.vector_size, req.window, req.epochs)
t0 = time.time()
w2v_engine = Word2VecEngine(vector_size=req.vector_size, window=req.window, epochs=req.epochs)
for i, text in enumerate(req.corpus_texts):
w2v_engine.add_document(f"doc_{i}", text)
stats = w2v_engine.build_index()
elapsed = round(time.time() - t0, 2)
logger.info("Word2Vec ready: %s in %.2fs", stats, elapsed)
# Auto-save so data persists across restarts
try:
w2v_engine.save(str(W2V_SAVE_DIR))
except Exception:
logger.warning("Auto-save W2V after init failed", exc_info=True)
return {**stats, "seconds": elapsed}
@app.post("/api/w2v/init-from-engine")
def w2v_init_from_engine(
vector_size: int = Query(default=100, ge=50, le=500),
window: int = Query(default=5, ge=1, le=20),
epochs: int = Query(default=50, ge=1, le=200),
):
"""Train Word2Vec directly from all documents already loaded in the engine.
This avoids the round-trip through the frontend and uses ALL engine docs.
"""
global w2v_engine
_ensure_engine()
if not engine.chunks:
raise HTTPException(400, "No documents in the engine. Load a dataset first.")
# Group chunks by doc_id to reconstruct full documents
docs: dict[str, list[str]] = {}
for chunk in engine.chunks:
if chunk.doc_id not in docs:
docs[chunk.doc_id] = []
docs[chunk.doc_id].append(chunk.text)
logger.info("Word2Vec init from engine: %d documents, vector_size=%d, window=%d, epochs=%d",
len(docs), vector_size, window, epochs)
t0 = time.time()
w2v_engine = Word2VecEngine(vector_size=vector_size, window=window, epochs=epochs)
for doc_id, chunks_list in docs.items():
w2v_engine.add_document(doc_id, "\n".join(chunks_list))
stats = w2v_engine.build_index()
elapsed = round(time.time() - t0, 2)
logger.info("Word2Vec ready: %s in %.2fs", stats, elapsed)
# Auto-save
try:
w2v_engine.save(str(W2V_SAVE_DIR))
except Exception:
logger.warning("Auto-save W2V after init failed", exc_info=True)
return {**stats, "seconds": elapsed, "documents_used": len(docs)}
@app.post("/api/w2v/compare")
def w2v_compare(req: W2VCompareRequest):
_ensure_w2v()
return {"text_a": req.text_a, "text_b": req.text_b,
"similarity": round(w2v_engine.compare_texts(req.text_a, req.text_b), 4)}
@app.post("/api/w2v/query")
def w2v_query(req: W2VQueryRequest):
_ensure_w2v()
results = w2v_engine.query(req.text, top_k=req.top_k)
return {"query": req.text, "results": [
{"rank": r.rank, "score": round(r.score, 4), "doc_id": r.doc_id, "text": r.text}
for r in results
]}
@app.post("/api/w2v/similar-words")
def w2v_similar_words(req: W2VWordRequest):
_ensure_w2v()
similar = w2v_engine.most_similar_words(req.word, req.top_k)
return {"word": req.word, "similar": [{"word": w, "score": round(s, 4)} for w, s in similar]}
@app.get("/api/w2v/status")
def w2v_status():
"""Check if Word2Vec is loaded (from training or restored from disk)."""
if w2v_engine is not None and w2v_engine.model is not None:
return {
"ready": True,
"vocab_size": len(w2v_engine.model.wv),
"sentences": len(w2v_engine.sentences),
"vector_size": w2v_engine.vector_size,
}
has_saved = Word2VecEngine.has_saved_state(str(W2V_SAVE_DIR))
return {"ready": False, "has_saved_state": has_saved}
@app.post("/api/w2v/reset")
def w2v_reset():
"""Delete saved Word2Vec state and clear the in-memory model."""
global w2v_engine
w2v_engine = None
import shutil
if W2V_SAVE_DIR.is_dir():
shutil.rmtree(str(W2V_SAVE_DIR))
logger.info("Word2Vec state deleted from %s", W2V_SAVE_DIR)
return {"status": "ok", "message": "Word2Vec state cleared. You can retrain now."}
# ------------------------------------------------------------------ #
# Dataset endpoints (HuggingFace Epstein Files)
# ------------------------------------------------------------------ #
@app.get("/api/dataset/info")
def dataset_info():
"""Get metadata about available HuggingFace datasets."""
return get_dataset_info()
@app.post("/api/dataset/load")
def dataset_load(req: DatasetLoadRequest):
"""Load Epstein Files dataset from HuggingFace into the engine."""
global engine, evaluator
if engine is None:
logger.info("Engine not initialized β€” auto-initializing with default model...")
engine = ContextualSimilarityEngine(
model_name="all-MiniLM-L6-v2",
chunk_size=512,
chunk_overlap=128,
batch_size=64,
)
evaluator = None
logger.info("Engine auto-initialized with all-MiniLM-L6-v2")
if req.source_filter and req.source_filter not in ALLOWED_SOURCE_FILTERS:
raise HTTPException(400, f"source_filter must be one of: {sorted(ALLOWED_SOURCE_FILTERS)}")
logger.info("Dataset load: source=%s, max_docs=%d, min_text=%d, filter=%s, build_index=%s",
req.source, req.max_docs, req.min_text_length, req.source_filter, req.build_index)
t0 = time.time()
try:
if req.source == "embeddings":
result = import_chromadb_to_engine(engine, max_chunks=req.max_docs * 10)
else:
result = load_raw_to_engine(
engine,
max_docs=req.max_docs,
min_text_length=req.min_text_length,
source_filter=req.source_filter,
build_index=req.build_index,
)
logger.info("Dataset loaded in %.1fs", time.time() - t0)
# Auto-save so data persists across restarts
try:
engine.save(str(ENGINE_SAVE_DIR))
except Exception:
logger.warning("Auto-save after dataset load failed", exc_info=True)
return result
except Exception:
logger.exception("Dataset load failed")
raise HTTPException(500, "Dataset load failed. Check server logs for details.")
@app.post("/api/dataset/preview")
def dataset_preview(
max_docs: int = Query(default=10, ge=1, le=100),
min_text_length: int = Query(default=100, ge=0, le=100_000),
source_filter: Optional[str] = Query(default=None),
):
"""Preview a few documents from the raw dataset without loading into engine."""
if source_filter and source_filter not in ALLOWED_SOURCE_FILTERS:
raise HTTPException(400, f"source_filter must be one of: {sorted(ALLOWED_SOURCE_FILTERS)}")
try:
docs = load_raw_dataset(
max_docs=max_docs,
min_text_length=min_text_length,
source_filter=source_filter,
)
return {
"count": len(docs),
"documents": [
{"doc_id": d["doc_id"], "filename": d["filename"],
"text_preview": d["text"][:500], "text_length": len(d["text"])}
for d in docs
],
}
except Exception:
logger.exception("Dataset preview failed")
raise HTTPException(500, "Dataset preview failed. Check server logs for details.")
# ------------------------------------------------------------------ #
# Helpers
# ------------------------------------------------------------------ #
def _ensure_engine():
    """Abort with 400 if the similarity engine has not been initialized."""
    if engine is None:
        raise HTTPException(400, "Engine not initialized. POST /api/init first.")
def _ensure_index():
    """Abort with 400 unless the engine exists and its index is built.

    Defensive fix: the engine-is-None case is checked too, so a call made
    without a preceding _ensure_engine() yields a clean 400 instead of an
    AttributeError on None.
    """
    if engine is None or engine.index is None:
        raise HTTPException(400, "Index not built. POST /api/index/build first.")
def _ensure_evaluator():
    """Lazily create the Evaluator once the engine and index are available."""
    global evaluator
    if evaluator is None:
        # Evaluator needs a built index, so both preconditions are checked.
        _ensure_engine(); _ensure_index()
        evaluator = Evaluator(engine)
def _ensure_w2v():
    """Abort with 400 if the Word2Vec baseline has not been trained/loaded."""
    if w2v_engine is None:
        raise HTTPException(400, "Word2Vec not initialized. POST /api/w2v/init first.")
def _serialize_analysis(analysis):
    """Convert a keyword-analysis result object into a JSON-safe dict.

    Context texts are truncated to 300 chars and similar passages to 200 to
    keep responses compact; all scores are rounded to 4 decimals.
    """
    return {
        "keyword": analysis.keyword,
        "total_occurrences": analysis.total_occurrences,
        "meaning_clusters": [{
            "cluster_id": c["cluster_id"], "size": c["size"],
            "representative_text": c["representative_text"],
            "contexts": [{"doc_id": ctx.chunk.doc_id, "chunk_index": ctx.chunk.chunk_index,
                          "text": ctx.chunk.text[:300], "highlight_positions": ctx.highlight_positions}
                         for ctx in c["contexts"]],
            "similar_passages": [{"rank": s.rank, "score": round(s.score, 4),
                                  "doc_id": s.chunk.doc_id, "text": s.chunk.text[:200]}
                                 for s in c["similar_passages"]],
        } for c in analysis.meaning_clusters],
        "cross_keyword_similarities": {k: round(v, 4) for k, v in analysis.cross_keyword_similarities.items()},
    }
# ------------------------------------------------------------------ #
# Static frontend (production build served from /frontend/dist)
# ------------------------------------------------------------------ #
# Location of the production frontend build output, if present.
_FRONTEND_DIR = BASE_DIR / "frontend" / "dist"
@app.get("/api/debug/frontend")
async def debug_frontend_files():
"""List all files in the frontend dist directory (for debugging deploys)."""
if not _FRONTEND_DIR.is_dir():
return {"exists": False, "path": str(_FRONTEND_DIR)}
files = []
for root, _dirs, filenames in os.walk(str(_FRONTEND_DIR)):
for f in filenames:
full = os.path.join(root, f)
rel = os.path.relpath(full, str(_FRONTEND_DIR))
files.append({"path": rel, "size": os.path.getsize(full)})
return {"exists": True, "path": str(_FRONTEND_DIR), "files": files}
if _FRONTEND_DIR.is_dir():
    # Catch-all SPA route, registered only when a build exists. FastAPI
    # matches more specific /api/* routes first.
    @app.get("/{full_path:path}")
    async def serve_frontend(full_path: str):
        """Serve the React SPA — static files or index.html fallback."""
        if full_path:
            file_path = (_FRONTEND_DIR / full_path).resolve()
            # Path-traversal guard: only serve files resolving inside dist/.
            if file_path.is_file() and file_path.is_relative_to(_FRONTEND_DIR):
                return FileResponse(file_path)
        # Unknown paths fall back to index.html for client-side routing.
        return FileResponse(_FRONTEND_DIR / "index.html")
    logger.info("Frontend serving enabled from %s", _FRONTEND_DIR)
if __name__ == "__main__":
import uvicorn
host = os.environ.get("HOST", "127.0.0.1")
port = int(os.environ.get("PORT", "8000"))
has_frontend = _FRONTEND_DIR.is_dir()
logger.info("=" * 60)
logger.info("Contextual Similarity API starting")
logger.info(" Server: http://%s:%d", host, port)
if has_frontend:
logger.info(" Frontend: http://%s:%d (built-in)", host, port)
else:
logger.info(" Frontend: http://localhost:5173 (dev server)")
logger.info(" API Docs: http://%s:%d/docs", host, port)
logger.info("=" * 60)
uvicorn.run(app, host=host, port=port)