"""FastAPI server for the Contextual Similarity Engine.

Endpoints:
    /api/train/*   — Train/adapt models (3 strategies)
    /api/init      — Load a model into the engine
    /api/documents — Add documents to the corpus
    /api/index/build — Build FAISS index
    /api/query     — Semantic search
    /api/compare   — Compare two texts
    /api/analyze/* — Keyword analysis
    /api/match     — Keyword meaning matching
    /api/eval/*    — Evaluation metrics
    /api/w2v/*     — Word2Vec baseline comparison
    /api/dataset/* — HuggingFace dataset loading (Epstein Files)
"""

import asyncio
import logging
import os
import time
import threading
from collections import deque
from pathlib import Path
from typing import Literal, Optional

from fastapi import FastAPI, HTTPException, Query, UploadFile, File, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field

from contextual_similarity import ContextualSimilarityEngine
from evaluation import Evaluator, GroundTruthEntry
from training import CorpusTrainer
from word2vec_baseline import Word2VecEngine
from data_loader import load_raw_dataset, load_raw_to_engine, import_chromadb_to_engine, get_dataset_info

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# ------------------------------------------------------------------ #
# Log streaming buffer
# ------------------------------------------------------------------ #

class LogBuffer(logging.Handler):
    """Thread-safe log handler that buffers recent messages for SSE streaming.

    Consumers poll with :meth:`get_new_lines`, passing back the cursor they
    received previously; the buffer is bounded so the cursor may "skip"
    evicted lines when readers fall far behind.
    """

    def __init__(self, max_lines: int = 500):
        super().__init__()
        # Bounded deque: oldest lines are evicted once max_lines is reached.
        self._buffer: deque[str] = deque(maxlen=max_lines)
        self._lock = threading.Lock()
        # NOTE(fix): the previous implementation also kept a threading.Event
        # that was set on every emit but never waited on or cleared — dead
        # state, removed.

    def emit(self, record: logging.LogRecord) -> None:
        """Format the record and append it to the bounded buffer."""
        msg = self.format(record)
        with self._lock:
            self._buffer.append(msg)

    def get_new_lines(self, after: int) -> tuple[list[str], int]:
        """Return lines added after index `after`, and the new cursor."""
        with self._lock:
            all_lines = list(self._buffer)
            new_lines = all_lines[after:] if after < len(all_lines) else []
            return new_lines, len(all_lines)


log_buffer = LogBuffer()
log_buffer.setFormatter(logging.Formatter("%(asctime)s %(name)s %(message)s", datefmt="%H:%M:%S"))
log_buffer.setLevel(logging.INFO)
# Attach to root logger so all modules' logs are captured
logging.getLogger().addHandler(log_buffer)


# ------------------------------------------------------------------ #
# Security constants & validation helpers
# ------------------------------------------------------------------ #

ALLOWED_MODELS = frozenset({
    "all-MiniLM-L6-v2",
    "all-mpnet-base-v2",
    "BAAI/bge-large-en-v1.5",
})
ALLOWED_SOURCE_FILTERS = frozenset({"TEXT-", "IMAGES-"})
MAX_UPLOAD_BYTES = 10 * 1024 * 1024  # 10 MB
BASE_DIR = Path(__file__).parent.resolve()


def _validate_model_name(name: str) -> str:
    """Allow known HuggingFace models or local paths within project dir."""
    if name in ALLOWED_MODELS:
        return name
    # Treat as a local model path — must be within the project directory
    _validate_safe_path(name)
    return name


def _validate_safe_path(path_str: str) -> str:
    """Reject paths that escape the project directory.

    Raises:
        HTTPException: 400 if the resolved path is outside BASE_DIR.
    """
    resolved = Path(path_str).resolve()
    if not resolved.is_relative_to(BASE_DIR):
        raise HTTPException(400, "Path must be within the project directory.")
    return path_str


def _to_native(obj):
    """Recursively convert numpy types to native Python types for JSON serialization."""
    import numpy as np
    if isinstance(obj, dict):
        return {k: _to_native(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_to_native(v) for v in obj]
    if isinstance(obj, (np.integer,)):
        return int(obj)
    if isinstance(obj, (np.floating,)):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    return obj


# ------------------------------------------------------------------ #
# App setup
# ------------------------------------------------------------------ #

app = FastAPI(
    title="Contextual Similarity API",
    version="1.0.0",
)

app.add_middleware(
    CORSMiddleware,
    # NOTE(fix): the literal "https://*.hf.space" was removed from
    # allow_origins — CORSMiddleware compares allow_origins entries by string
    # equality, so a wildcard literal can never match a real Origin header.
    # The allow_origin_regex below is what actually matches *.hf.space.
    allow_origins=["http://localhost:5173", "http://localhost:3000",
                   "https://huggingface.co"],
    allow_origin_regex=r"https://.*\.hf\.space",
    allow_credentials=True,
    allow_methods=["GET", "POST"],
    allow_headers=["Content-Type", "Authorization"],
)

# Global instances (module-level singletons shared by all endpoints)
engine: Optional[ContextualSimilarityEngine] = None
evaluator: Optional[Evaluator] = None
w2v_engine: Optional[Word2VecEngine] = None

ENGINE_SAVE_DIR = Path(os.environ.get("ENGINE_STATE_DIR", str(BASE_DIR / "engine_state")))
W2V_SAVE_DIR = Path(os.environ.get("W2V_STATE_DIR", str(BASE_DIR / "w2v_state")))


@app.on_event("startup")
def _auto_restore():
    """Restore engine and W2V state from disk if previous saves exist."""
    global engine, evaluator, w2v_engine
    if (ENGINE_SAVE_DIR / "meta.json").is_file():
        try:
            engine = ContextualSimilarityEngine.load(str(ENGINE_SAVE_DIR))
            if engine.index is not None:
                evaluator = Evaluator(engine)
            logger.info("Auto-restored engine: %d chunks, %d docs",
                        len(engine.chunks), len(engine._doc_ids))
        except Exception:
            # Best-effort restore: a corrupt save must not block startup.
            logger.exception("Failed to auto-restore engine state — starting fresh")
    if Word2VecEngine.has_saved_state(str(W2V_SAVE_DIR)):
        try:
            w2v_engine = Word2VecEngine.load(str(W2V_SAVE_DIR))
            logger.info("Auto-restored Word2Vec: %d sentences, %d vocab",
                        len(w2v_engine.sentences), len(w2v_engine.model.wv))
        except Exception:
            logger.exception("Failed to auto-restore Word2Vec state — starting fresh")


@app.get("/api/logs/stream")
async def stream_logs():
    """SSE endpoint: streams server log lines in real-time."""
    async def event_generator():
        cursor = 0
        # Send initial snapshot (last 20 lines), then advance the cursor to
        # the end of the buffer so we only stream genuinely new lines.
        lines, cursor = log_buffer.get_new_lines(cursor)
        for line in lines[-20:]:
            yield f"data: {line}\n\n"
        while True:
            await asyncio.sleep(0.5)
            lines, cursor = log_buffer.get_new_lines(cursor)
            for line in lines:
                yield f"data: {line}\n\n"

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )


@app.get("/api/logs/poll")
def poll_logs(cursor: int = Query(default=0, ge=0)):
    """Polling fallback for log streaming (works through HF Spaces proxy)."""
    lines, new_cursor = log_buffer.get_new_lines(cursor)
    return {"lines": lines, "cursor": new_cursor}


# ------------------------------------------------------------------ #
# Request models (with input validation)
# ------------------------------------------------------------------ #

class TrainRequest(BaseModel):
    corpus_texts: list[str] = Field(max_length=10_000)
    base_model: str = "all-MiniLM-L6-v2"
    output_path: str = "./trained_model"
    epochs: int = Field(default=5, ge=1, le=50)
    batch_size: int = Field(default=16, ge=1, le=256)


class TrainKeywordsRequest(TrainRequest):
    keyword_meanings: dict[str, str]


class TrainEvalRequest(BaseModel):
    test_pairs: list[dict] = Field(max_length=1_000)
    trained_model_path: str = "./trained_model"
    base_model: str = "all-MiniLM-L6-v2"
    # default_factory avoids sharing a mutable default between instances
    corpus_texts: list[str] = Field(default_factory=list, max_length=10_000)


class InitRequest(BaseModel):
    model_name: str = "all-MiniLM-L6-v2"
    chunk_size: int = Field(default=512, ge=50, le=8192)
    chunk_overlap: int = Field(default=128, ge=0, le=4096)
    batch_size: int = Field(default=64, ge=1, le=512)


class DocumentRequest(BaseModel):
    doc_id: str = Field(max_length=200)
    text: str = Field(max_length=10_000_000)


class QueryRequest(BaseModel):
    text: str = Field(max_length=10_000)
    top_k: int = Field(default=10, ge=1, le=100)


class CompareRequest(BaseModel):
    text_a: str = Field(max_length=50_000)
    text_b: str = Field(max_length=50_000)


class KeywordAnalysisRequest(BaseModel):
    keyword: str = Field(max_length=200)
    top_k: int = Field(default=10, ge=1, le=100)
    cluster_threshold: float = Field(default=0.35, ge=0.01, le=1.0)


class BatchAnalysisRequest(BaseModel):
    keywords: list[str] = Field(max_length=50)
    top_k: int = Field(default=10, ge=1, le=100)
    cluster_threshold: float = Field(default=0.35, ge=0.01, le=1.0)
    compare_across: bool = True


class KeywordMatchRequest(BaseModel):
    keyword: str = Field(max_length=200)
    candidate_meanings: list[str] = Field(max_length=50)


class EvalDisambiguationRequest(BaseModel):
    ground_truth: list[dict] = Field(max_length=10_000)
    candidate_meanings: dict[str, list[str]]


class EvalRetrievalRequest(BaseModel):
    queries: list[dict] = Field(max_length=1_000)
    k_values: list[int] = Field(default_factory=lambda: [1, 3, 5, 10], max_length=10)


class W2VInitRequest(BaseModel):
    corpus_texts: list[str] = Field(max_length=10_000)
    vector_size: int = Field(default=100, ge=50, le=500)
    window: int = Field(default=5, ge=1, le=20)
    epochs: int = Field(default=50, ge=1, le=200)


class W2VCompareRequest(BaseModel):
    text_a: str = Field(max_length=50_000)
    text_b: str = Field(max_length=50_000)


class W2VQueryRequest(BaseModel):
    text: str = Field(max_length=10_000)
    top_k: int = Field(default=10, ge=1, le=100)


class W2VWordRequest(BaseModel):
    word: str = Field(max_length=200)
    top_k: int = Field(default=10, ge=1, le=100)


class ContextAnalysisRequest(BaseModel):
    keyword: str = Field(max_length=200)
    cluster_threshold: float = Field(default=0.35, ge=0.01, le=1.0)
    top_words: int = Field(default=8, ge=1, le=30)


class DatasetLoadRequest(BaseModel):
    source: Literal["raw", "embeddings"] = "raw"
    max_docs: int = Field(default=500, ge=1, le=100_000)
    min_text_length: int = Field(default=100, ge=0, le=100_000)
    source_filter: Optional[str] = None
    build_index: bool = True


# ------------------------------------------------------------------ #
# Training endpoints
# ------------------------------------------------------------------ #

def _run_training(req: TrainRequest, strategy: str, train_fn):
    """Common wrapper for all training endpoints: validate, log, time, train.

    Args:
        req: validated training request (model, corpus, hyperparameters).
        strategy: human-readable strategy name, used in log lines only.
        train_fn: callable receiving a CorpusTrainer and returning the result.
    """
    _validate_model_name(req.base_model)
    _validate_safe_path(req.output_path)
    logger.info("Training (%s): model=%s, corpus=%d texts, epochs=%d, batch=%d",
                strategy, req.base_model, len(req.corpus_texts), req.epochs, req.batch_size)
    t0 = time.time()
    trainer = CorpusTrainer(req.corpus_texts, req.base_model)
    result = train_fn(trainer)
    logger.info("Training (%s) complete in %.1fs → %s",
                strategy, time.time() - t0, req.output_path)
    return result


@app.post("/api/train/unsupervised")
def train_unsupervised(req: TrainRequest):
    """Soft-label domain adaptation. No labels needed."""
    return _run_training(req, "unsupervised",
                         lambda t: t.train_unsupervised(req.output_path, req.epochs, req.batch_size))


@app.post("/api/train/contrastive")
def train_contrastive(req: TrainRequest):
    """Contrastive: learns from corpus structure (adjacent sentences = similar)."""
    return _run_training(req, "contrastive",
                         lambda t: t.train_contrastive(req.output_path, req.epochs, req.batch_size))


@app.post("/api/train/keywords")
def train_keywords(req: TrainKeywordsRequest):
    """Keyword-supervised: provide keyword→meaning map, pairs auto-generated."""
    return _run_training(req, "keyword-supervised",
                         lambda t: t.train_with_keywords(req.keyword_meanings, req.output_path,
                                                         req.epochs, req.batch_size))


@app.post("/api/train/evaluate")
def train_evaluate(req: TrainEvalRequest):
    """Compare base model vs trained model on test pairs."""
    _validate_model_name(req.base_model)
    _validate_model_name(req.trained_model_path)
    logger.info("Evaluating: base=%s vs trained=%s, %d test pairs",
                req.base_model, req.trained_model_path, len(req.test_pairs))
    # Trainer requires a non-empty corpus even when only evaluating.
    corpus = req.corpus_texts or ["placeholder text for initialization."]
    trainer = CorpusTrainer(corpus, req.base_model)
    test_pairs = [
        (p["text_a"], p["text_b"], p.get("expected", p.get("score", 0.5)))
        for p in req.test_pairs
    ]
    result = trainer.evaluate_model(test_pairs, req.trained_model_path)
    logger.info("Evaluation complete: %d pairs evaluated", len(test_pairs))
    return result


# ------------------------------------------------------------------ #
# Engine endpoints
# ------------------------------------------------------------------ #
@app.post("/api/init")
def init_engine(req: InitRequest):
    """Initialize the similarity engine with a model (pretrained or trained)."""
    _validate_model_name(req.model_name)
    if req.chunk_overlap >= req.chunk_size:
        raise HTTPException(400, "chunk_overlap must be less than chunk_size.")
    global engine, evaluator
    logger.info("Initializing engine: model=%s, chunk_size=%d, overlap=%d, batch=%d",
                req.model_name, req.chunk_size, req.chunk_overlap, req.batch_size)
    t0 = time.time()
    engine = ContextualSimilarityEngine(
        model_name=req.model_name,
        chunk_size=req.chunk_size,
        chunk_overlap=req.chunk_overlap,
        batch_size=req.batch_size,
    )
    # Any previous evaluator is bound to the old engine/index — discard it.
    evaluator = None
    elapsed = round(time.time() - t0, 2)
    logger.info("Engine initialized in %.2fs (model=%s)", elapsed, req.model_name)
    return {"status": "ok", "model": req.model_name, "load_time_seconds": elapsed}


@app.post("/api/documents")
def add_document(req: DocumentRequest):
    """Add a single document (by id + raw text) to the engine's corpus."""
    _ensure_engine()
    logger.info("Adding document: id=%s, text_length=%d", req.doc_id, len(req.text))
    try:
        chunks = engine.add_document(req.doc_id, req.text)
    except ValueError as e:
        # Chain the cause so the original error is preserved in tracebacks.
        raise HTTPException(status_code=400, detail=str(e)) from e
    logger.info("Document '%s' added: %d chunks", req.doc_id, len(chunks))
    return {
        "status": "ok",
        "doc_id": req.doc_id,
        "num_chunks": len(chunks),
        "chunks_preview": [{"index": c.chunk_index, "text": c.text[:150]} for c in chunks[:5]],
    }


@app.post("/api/documents/upload")
async def upload_document(file: UploadFile = File(...), doc_id: Optional[str] = Form(None)):
    """Upload a UTF-8 text file as a document; doc_id defaults to the filename stem."""
    _ensure_engine()
    contents = await file.read()
    if len(contents) > MAX_UPLOAD_BYTES:
        raise HTTPException(413, f"File too large. Maximum size is {MAX_UPLOAD_BYTES // (1024 * 1024)}MB.")
    try:
        text = contents.decode("utf-8")
    except UnicodeDecodeError as e:
        raise HTTPException(400, "File must be valid UTF-8 text.") from e
    d_id = doc_id or Path(file.filename or "upload").stem
    try:
        chunks = engine.add_document(d_id, text)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    return {"status": "ok", "doc_id": d_id, "num_chunks": len(chunks)}


@app.post("/api/index/build")
def build_index():
    """Build the FAISS index over all added documents and enable evaluation."""
    _ensure_engine()
    logger.info("Building FAISS index (%d documents in corpus)...", len(engine.corpus))
    t0 = time.time()
    try:
        engine.build_index(show_progress=True)
    except RuntimeError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    global evaluator
    evaluator = Evaluator(engine)
    elapsed = round(time.time() - t0, 2)
    logger.info("FAISS index built: %d vectors (dim=%d) in %.2fs",
                engine.index.ntotal, engine.embedding_dim, elapsed)
    # Auto-save so data persists across restarts (best-effort)
    try:
        engine.save(str(ENGINE_SAVE_DIR))
    except Exception:
        logger.warning("Auto-save after index build failed", exc_info=True)
    return {
        "status": "ok",
        "total_chunks": engine.index.ntotal,
        "embedding_dim": engine.embedding_dim,
        "build_time_seconds": elapsed,
    }


@app.post("/api/query")
def query_similar(req: QueryRequest):
    """Semantic search: return top_k most similar chunks to the query text."""
    _ensure_engine(); _ensure_index()
    logger.info("Query: text='%s...' top_k=%d", req.text[:80], req.top_k)
    results = engine.query(req.text, top_k=req.top_k)
    return {"query": req.text, "results": [
        {"rank": r.rank, "score": round(r.score, 4), "doc_id": r.chunk.doc_id,
         "chunk_index": r.chunk.chunk_index, "text": r.chunk.text}
        for r in results
    ]}


@app.post("/api/compare")
def compare_texts(req: CompareRequest):
    """Return the embedding similarity between two free texts."""
    _ensure_engine()
    logger.info("Compare: text_a='%s...' vs text_b='%s...'", req.text_a[:60], req.text_b[:60])
    return {"text_a": req.text_a, "text_b": req.text_b,
            "similarity": round(engine.compare_texts(req.text_a, req.text_b), 4)}


@app.post("/api/analyze/keyword")
def analyze_keyword(req: KeywordAnalysisRequest):
    """Cluster a keyword's occurrences into distinct meanings."""
    _ensure_engine(); _ensure_index()
    logger.info("Keyword analysis: keyword='%s', top_k=%d, threshold=%.2f",
                req.keyword, req.top_k, req.cluster_threshold)
    return _serialize_analysis(engine.analyze_keyword(req.keyword, req.top_k, req.cluster_threshold))


@app.post("/api/analyze/similar-words")
def analyze_similar_words(req: W2VWordRequest):
    """Find words that appear in similar contexts using transformer embeddings."""
    _ensure_engine(); _ensure_index()
    logger.info("Similar words (transformer): word='%s', top_k=%d", req.word, req.top_k)
    results = engine.similar_words(req.word, req.top_k)
    return {"word": req.word, "similar": results}


@app.post("/api/analyze/context")
def analyze_context(req: ContextAnalysisRequest):
    """Infer what a keyword likely means from its surrounding context words."""
    _ensure_engine(); _ensure_index()
    logger.info("Context analysis: keyword='%s', threshold=%.2f, top_words=%d",
                req.keyword, req.cluster_threshold, req.top_words)
    return engine.infer_keyword_meanings(
        req.keyword,
        top_words=req.top_words,
        cluster_threshold=req.cluster_threshold,
    )


@app.post("/api/analyze/batch")
def batch_analyze(req: BatchAnalysisRequest):
    """Analyze several keywords at once, optionally comparing across them."""
    _ensure_engine(); _ensure_index()
    logger.info("Batch analysis: %d keywords=%s, top_k=%d",
                len(req.keywords), req.keywords[:5], req.top_k)
    results = engine.batch_analyze_keywords(req.keywords, req.top_k,
                                            req.cluster_threshold, req.compare_across)
    return {kw: _serialize_analysis(a) for kw, a in results.items()}


@app.post("/api/match")
def match_keyword(req: KeywordMatchRequest):
    """Match each occurrence of a keyword against candidate meaning descriptions."""
    _ensure_engine(); _ensure_index()
    results = engine.match_keyword_to_meaning(req.keyword, req.candidate_meanings)
    return {"keyword": req.keyword, "candidate_meanings": req.candidate_meanings, "matches": [
        {"doc_id": r["chunk"].doc_id, "chunk_index": r["chunk"].chunk_index,
         "text": r["chunk"].text[:300], "best_match": r["best_match"],
         "best_score": round(r["best_score"], 4),
         "all_scores": {k: round(v, 4) for k, v in r["all_scores"].items()}}
        for r in results
    ]}


# ------------------------------------------------------------------ #
# Evaluation endpoints
# ------------------------------------------------------------------ #

@app.post("/api/eval/disambiguation")
def evaluate_disambiguation(req: EvalDisambiguationRequest):
    """Evaluate disambiguation accuracy/F1 against labelled ground truth."""
    _ensure_evaluator()
    gt = [GroundTruthEntry(keyword=e["keyword"], text=e["text"], true_meaning=e["true_meaning"])
          for e in req.ground_truth]
    metrics = evaluator.evaluate_disambiguation(gt, req.candidate_meanings)
    return _to_native({"metrics": [
        {"keyword": m.keyword, "accuracy": m.accuracy, "weighted_f1": m.weighted_f1,
         "per_meaning_precision": m.per_meaning_precision,
         "per_meaning_recall": m.per_meaning_recall,
         "per_meaning_f1": m.per_meaning_f1,
         "confusion_matrix": m.confusion, "total_samples": m.total_samples}
        for m in metrics
    ]})


@app.post("/api/eval/retrieval")
def evaluate_retrieval(req: EvalRetrievalRequest):
    """Evaluate retrieval quality (MRR, P@k, R@k, nDCG@k) on labelled queries."""
    _ensure_evaluator()
    metrics = evaluator.evaluate_retrieval(req.queries, req.k_values)
    return _to_native({"metrics": [
        {"query": m.query, "mrr": round(float(m.mrr), 4),
         "precision_at_k": {str(k): round(float(v), 4) for k, v in m.precision_at_k.items()},
         "recall_at_k": {str(k): round(float(v), 4) for k, v in m.recall_at_k.items()},
         "ndcg_at_k": {str(k): round(float(v), 4) for k, v in m.ndcg_at_k.items()},
         "avg_similarity": round(float(m.avg_similarity), 4),
         "top_score": round(float(m.top_score), 4)}
        for m in metrics
    ]})


@app.get("/api/eval/similarity-distribution")
def similarity_distribution():
    """Return the distribution of pairwise similarities over the corpus."""
    _ensure_evaluator()
    return _to_native(evaluator.analyze_similarity_distribution())


@app.get("/api/eval/report")
def get_eval_report():
    """Return a summary of all evaluations run so far."""
    _ensure_evaluator()
    return _to_native(evaluator.get_report().summary())


@app.get("/api/stats")
def get_stats():
    """Return engine statistics (corpus/index sizes, model info)."""
    _ensure_engine()
    return engine.get_stats()


@app.get("/api/corpus/texts")
def get_corpus_texts(max_docs: int = Query(default=500, ge=1, le=10_000)):
    """Return loaded document texts grouped by doc_id (for use as training corpus)."""
    _ensure_engine()
    # Group chunks by doc_id
    docs: dict[str, list[str]] = {}
    for chunk in engine.chunks:
        docs.setdefault(chunk.doc_id, []).append(chunk.text)
    # Combine chunks per document, capped at max_docs
    result = []
    for doc_id in sorted(docs.keys()):
        if len(result) >= max_docs:
            break
        result.append({"doc_id": doc_id, "text": "\n".join(docs[doc_id])})
    return {"documents": result, "count": len(result)}


@app.get("/api/documents/{doc_id}")
def get_document(doc_id: str):
    """Return the full text of a document by reconstructing its chunks."""
    _ensure_engine()
    chunks = [c for c in engine.chunks if c.doc_id == doc_id]
    if not chunks:
        raise HTTPException(404, f"Document '{doc_id}' not found.")
    chunks.sort(key=lambda c: c.chunk_index)
    full_text = "\n".join(c.text for c in chunks)
    return {"doc_id": doc_id, "text": full_text, "num_chunks": len(chunks)}


@app.post("/api/engine/save")
def save_engine():
    """Save current engine state to disk for later restore."""
    _ensure_engine()
    result = engine.save(str(ENGINE_SAVE_DIR))
    return {"status": "ok", **result}


@app.post("/api/engine/load")
def load_engine_state():
    """Load a previously saved engine state from disk."""
    global engine, evaluator
    if not (ENGINE_SAVE_DIR / "meta.json").is_file():
        raise HTTPException(400, "No saved engine state found.")
    engine = ContextualSimilarityEngine.load(str(ENGINE_SAVE_DIR))
    evaluator = Evaluator(engine) if engine.index is not None else None
    return {"status": "ok", **engine.get_stats()}


@app.get("/api/engine/has-saved-state")
def has_saved_state():
    """Check if a saved engine state exists on disk."""
    exists = (ENGINE_SAVE_DIR / "meta.json").is_file()
    return {"exists": exists}


# ------------------------------------------------------------------ #
# Word2Vec baseline endpoints
# ------------------------------------------------------------------ #

@app.post("/api/w2v/init")
def w2v_init(req: W2VInitRequest):
    """Train Word2Vec on corpus for comparison."""
    global w2v_engine
    logger.info("Word2Vec init: %d texts, vector_size=%d, window=%d, epochs=%d",
                len(req.corpus_texts), req.vector_size, req.window, req.epochs)
    t0 = time.time()
    w2v_engine = Word2VecEngine(vector_size=req.vector_size, window=req.window, epochs=req.epochs)
    for i, text in enumerate(req.corpus_texts):
        w2v_engine.add_document(f"doc_{i}", text)
    stats = w2v_engine.build_index()
    elapsed = round(time.time() - t0, 2)
    logger.info("Word2Vec ready: %s in %.2fs", stats, elapsed)
    # Auto-save so data persists across restarts (best-effort)
    try:
        w2v_engine.save(str(W2V_SAVE_DIR))
    except Exception:
        logger.warning("Auto-save W2V after init failed", exc_info=True)
    return {**stats, "seconds": elapsed}


@app.post("/api/w2v/init-from-engine")
def w2v_init_from_engine(
    vector_size: int = Query(default=100, ge=50, le=500),
    window: int = Query(default=5, ge=1, le=20),
    epochs: int = Query(default=50, ge=1, le=200),
):
    """Train Word2Vec directly from all documents already loaded in the engine.

    This avoids the round-trip through the frontend and uses ALL engine docs.
    """
    global w2v_engine
    _ensure_engine()
    if not engine.chunks:
        raise HTTPException(400, "No documents in the engine. Load a dataset first.")
    # Group chunks by doc_id to reconstruct full documents
    docs: dict[str, list[str]] = {}
    for chunk in engine.chunks:
        docs.setdefault(chunk.doc_id, []).append(chunk.text)
    logger.info("Word2Vec init from engine: %d documents, vector_size=%d, window=%d, epochs=%d",
                len(docs), vector_size, window, epochs)
    t0 = time.time()
    w2v_engine = Word2VecEngine(vector_size=vector_size, window=window, epochs=epochs)
    for doc_id, chunks_list in docs.items():
        w2v_engine.add_document(doc_id, "\n".join(chunks_list))
    stats = w2v_engine.build_index()
    elapsed = round(time.time() - t0, 2)
    logger.info("Word2Vec ready: %s in %.2fs", stats, elapsed)
    # Auto-save (best-effort)
    try:
        w2v_engine.save(str(W2V_SAVE_DIR))
    except Exception:
        logger.warning("Auto-save W2V after init failed", exc_info=True)
    return {**stats, "seconds": elapsed, "documents_used": len(docs)}


@app.post("/api/w2v/compare")
def w2v_compare(req: W2VCompareRequest):
    """Compare two texts using the Word2Vec baseline."""
    _ensure_w2v()
    return {"text_a": req.text_a, "text_b": req.text_b,
            "similarity": round(w2v_engine.compare_texts(req.text_a, req.text_b), 4)}


@app.post("/api/w2v/query")
def w2v_query(req: W2VQueryRequest):
    """Semantic search over the Word2Vec baseline index."""
    _ensure_w2v()
    results = w2v_engine.query(req.text, top_k=req.top_k)
    return {"query": req.text, "results": [
        {"rank": r.rank, "score": round(r.score, 4), "doc_id": r.doc_id, "text": r.text}
        for r in results
    ]}


@app.post("/api/w2v/similar-words")
def w2v_similar_words(req: W2VWordRequest):
    """Nearest words in the Word2Vec vocabulary."""
    _ensure_w2v()
    similar = w2v_engine.most_similar_words(req.word, req.top_k)
    return {"word": req.word, "similar": [{"word": w, "score": round(s, 4)} for w, s in similar]}


@app.get("/api/w2v/status")
def w2v_status():
    """Check if Word2Vec is loaded (from training or restored from disk)."""
    if w2v_engine is not None and w2v_engine.model is not None:
        return {
            "ready": True,
            "vocab_size": len(w2v_engine.model.wv),
            "sentences": len(w2v_engine.sentences),
            "vector_size": w2v_engine.vector_size,
        }
    has_saved = Word2VecEngine.has_saved_state(str(W2V_SAVE_DIR))
    return {"ready": False, "has_saved_state": has_saved}


@app.post("/api/w2v/reset")
def w2v_reset():
    """Delete saved Word2Vec state and clear the in-memory model."""
    # Local import kept function-scoped (shutil only used here), but hoisted
    # above any logic so it is in scope for the whole function body.
    import shutil
    global w2v_engine
    w2v_engine = None
    if W2V_SAVE_DIR.is_dir():
        shutil.rmtree(str(W2V_SAVE_DIR))
        logger.info("Word2Vec state deleted from %s", W2V_SAVE_DIR)
    return {"status": "ok", "message": "Word2Vec state cleared. You can retrain now."}


# ------------------------------------------------------------------ #
# Dataset endpoints (HuggingFace Epstein Files)
# ------------------------------------------------------------------ #

@app.get("/api/dataset/info")
def dataset_info():
    """Get metadata about available HuggingFace datasets."""
    return get_dataset_info()


@app.post("/api/dataset/load")
def dataset_load(req: DatasetLoadRequest):
    """Load Epstein Files dataset from HuggingFace into the engine."""
    global engine, evaluator
    if engine is None:
        logger.info("Engine not initialized — auto-initializing with default model...")
        engine = ContextualSimilarityEngine(
            model_name="all-MiniLM-L6-v2",
            chunk_size=512,
            chunk_overlap=128,
            batch_size=64,
        )
        evaluator = None
        logger.info("Engine auto-initialized with all-MiniLM-L6-v2")
    if req.source_filter and req.source_filter not in ALLOWED_SOURCE_FILTERS:
        raise HTTPException(400, f"source_filter must be one of: {sorted(ALLOWED_SOURCE_FILTERS)}")
    logger.info("Dataset load: source=%s, max_docs=%d, min_text=%d, filter=%s, build_index=%s",
                req.source, req.max_docs, req.min_text_length, req.source_filter, req.build_index)
    t0 = time.time()
    try:
        if req.source == "embeddings":
            result = import_chromadb_to_engine(engine, max_chunks=req.max_docs * 10)
        else:
            result = load_raw_to_engine(
                engine,
                max_docs=req.max_docs,
                min_text_length=req.min_text_length,
                source_filter=req.source_filter,
                build_index=req.build_index,
            )
        logger.info("Dataset loaded in %.1fs", time.time() - t0)
        # Auto-save so data persists across restarts (best-effort)
        try:
            engine.save(str(ENGINE_SAVE_DIR))
        except Exception:
            logger.warning("Auto-save after dataset load failed", exc_info=True)
        return result
    except Exception:
        logger.exception("Dataset load failed")
        raise HTTPException(500, "Dataset load failed. Check server logs for details.")


@app.post("/api/dataset/preview")
def dataset_preview(
    max_docs: int = Query(default=10, ge=1, le=100),
    min_text_length: int = Query(default=100, ge=0, le=100_000),
    source_filter: Optional[str] = Query(default=None),
):
    """Preview a few documents from the raw dataset without loading into engine."""
    if source_filter and source_filter not in ALLOWED_SOURCE_FILTERS:
        raise HTTPException(400, f"source_filter must be one of: {sorted(ALLOWED_SOURCE_FILTERS)}")
    try:
        docs = load_raw_dataset(
            max_docs=max_docs,
            min_text_length=min_text_length,
            source_filter=source_filter,
        )
        return {
            "count": len(docs),
            "documents": [
                {"doc_id": d["doc_id"], "filename": d["filename"],
                 "text_preview": d["text"][:500], "text_length": len(d["text"])}
                for d in docs
            ],
        }
    except Exception:
        logger.exception("Dataset preview failed")
        raise HTTPException(500, "Dataset preview failed. Check server logs for details.")


# ------------------------------------------------------------------ #
# Helpers
# ------------------------------------------------------------------ #

def _ensure_engine():
    """Raise 400 unless the engine singleton has been initialized."""
    if engine is None:
        raise HTTPException(400, "Engine not initialized. POST /api/init first.")


def _ensure_index():
    """Raise 400 unless the FAISS index has been built.

    Also guards the engine itself: previously this dereferenced
    ``engine.index`` directly and raised AttributeError (HTTP 500) when the
    engine was never initialized.
    """
    _ensure_engine()
    if engine.index is None:
        raise HTTPException(400, "Index not built. POST /api/index/build first.")


def _ensure_evaluator():
    """Lazily build the evaluator once an engine + index exist."""
    global evaluator
    if evaluator is None:
        _ensure_engine(); _ensure_index()
        evaluator = Evaluator(engine)


def _ensure_w2v():
    """Raise 400 unless the Word2Vec baseline has been trained/restored."""
    if w2v_engine is None:
        raise HTTPException(400, "Word2Vec not initialized. POST /api/w2v/init first.")


def _serialize_analysis(analysis):
    """Convert a KeywordAnalysis-like object into a JSON-serializable dict."""
    return {
        "keyword": analysis.keyword,
        "total_occurrences": analysis.total_occurrences,
        "meaning_clusters": [{
            "cluster_id": c["cluster_id"],
            "size": c["size"],
            "representative_text": c["representative_text"],
            "contexts": [{"doc_id": ctx.chunk.doc_id, "chunk_index": ctx.chunk.chunk_index,
                          "text": ctx.chunk.text[:300],
                          "highlight_positions": ctx.highlight_positions}
                         for ctx in c["contexts"]],
            "similar_passages": [{"rank": s.rank, "score": round(s.score, 4),
                                  "doc_id": s.chunk.doc_id, "text": s.chunk.text[:200]}
                                 for s in c["similar_passages"]],
        } for c in analysis.meaning_clusters],
        "cross_keyword_similarities": {k: round(v, 4)
                                       for k, v in analysis.cross_keyword_similarities.items()},
    }


# ------------------------------------------------------------------ #
# Static frontend (production build served from /frontend/dist)
# ------------------------------------------------------------------ #

_FRONTEND_DIR = BASE_DIR / "frontend" / "dist"


@app.get("/api/debug/frontend")
async def debug_frontend_files():
    """List all files in the frontend dist directory (for debugging deploys)."""
    if not _FRONTEND_DIR.is_dir():
        return {"exists": False, "path": str(_FRONTEND_DIR)}
    files = []
    for root, _dirs, filenames in os.walk(str(_FRONTEND_DIR)):
        for f in filenames:
            full = os.path.join(root, f)
            rel = os.path.relpath(full, str(_FRONTEND_DIR))
            files.append({"path": rel, "size": os.path.getsize(full)})
    return {"exists": True, "path": str(_FRONTEND_DIR), "files": files}


if _FRONTEND_DIR.is_dir():
    @app.get("/{full_path:path}")
    async def serve_frontend(full_path: str):
        """Serve the React SPA — static files or index.html fallback."""
        if full_path:
            file_path = (_FRONTEND_DIR / full_path).resolve()
            # Path containment check prevents traversal outside the dist dir.
            if file_path.is_file() and file_path.is_relative_to(_FRONTEND_DIR):
                return FileResponse(file_path)
        return FileResponse(_FRONTEND_DIR / "index.html")

    logger.info("Frontend serving enabled from %s", _FRONTEND_DIR)


if __name__ == "__main__":
    import uvicorn

    host = os.environ.get("HOST", "127.0.0.1")
    port = int(os.environ.get("PORT", "8000"))
    has_frontend = _FRONTEND_DIR.is_dir()
    logger.info("=" * 60)
    logger.info("Contextual Similarity API starting")
    logger.info("  Server:   http://%s:%d", host, port)
    if has_frontend:
        logger.info("  Frontend: http://%s:%d (built-in)", host, port)
    else:
        logger.info("  Frontend: http://localhost:5173 (dev server)")
    logger.info("  API Docs: http://%s:%d/docs", host, port)
    logger.info("=" * 60)
    uvicorn.run(app, host=host, port=port)