| """ |
| FastAPI server for the Contextual Similarity Engine. |
| |
| Endpoints: |
| /api/train/* β Train/adapt models (3 strategies) |
| /api/init β Load a model into the engine |
| /api/documents β Add documents to the corpus |
| /api/index/build β Build FAISS index |
| /api/query β Semantic search |
| /api/compare β Compare two texts |
| /api/analyze/* β Keyword analysis |
| /api/match β Keyword meaning matching |
| /api/eval/* β Evaluation metrics |
| /api/w2v/* β Word2Vec baseline comparison |
| /api/dataset/* β HuggingFace dataset loading (Epstein Files) |
| """ |
|
|
| import asyncio |
| import logging |
| import os |
| import time |
| import threading |
| from collections import deque |
| from pathlib import Path |
| from typing import Literal, Optional |
|
|
| from fastapi import FastAPI, HTTPException, Query, UploadFile, File, Form |
| from fastapi.middleware.cors import CORSMiddleware |
| from fastapi.responses import StreamingResponse, FileResponse |
| from fastapi.staticfiles import StaticFiles |
| from pydantic import BaseModel, Field |
|
|
| from contextual_similarity import ContextualSimilarityEngine |
| from evaluation import Evaluator, GroundTruthEntry |
| from training import CorpusTrainer |
| from word2vec_baseline import Word2VecEngine |
| from data_loader import load_raw_dataset, load_raw_to_engine, import_chromadb_to_engine, get_dataset_info |
|
|
# Root logging config; the LogBuffer handler defined below is attached on top.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
|
| |
| |
| |
|
|
class LogBuffer(logging.Handler):
    """Thread-safe log handler that buffers recent messages for SSE streaming.

    Keeps at most `max_lines` formatted records in a bounded deque; readers
    poll with a cursor via get_new_lines().
    """

    def __init__(self, max_lines: int = 500):
        super().__init__()
        self._buffer: deque[str] = deque(maxlen=max_lines)
        self._lock = threading.Lock()
        self._event = threading.Event()

    def emit(self, record: logging.LogRecord) -> None:
        # Format first, then mutate the shared buffer under the lock.
        formatted = self.format(record)
        with self._lock:
            self._buffer.append(formatted)
            self._event.set()

    def get_new_lines(self, after: int) -> tuple[list[str], int]:
        """Return lines added after index `after`, and the new cursor."""
        with self._lock:
            snapshot = list(self._buffer)
        total = len(snapshot)
        fresh = snapshot[after:] if after < total else []
        return fresh, total
|
|
|
|
|
|
# Module-level buffer feeding the /api/logs/stream and /api/logs/poll endpoints.
log_buffer = LogBuffer()
log_buffer.setFormatter(logging.Formatter("%(asctime)s %(name)s %(message)s", datefmt="%H:%M:%S"))
log_buffer.setLevel(logging.INFO)
# Attach to the root logger so INFO+ records from every module are captured.
logging.getLogger().addHandler(log_buffer)
|
|
|
|
| |
| |
| |
|
|
# Embedding models loadable by name; anything else must be a local path
# inside the project directory (see _validate_model_name).
ALLOWED_MODELS = frozenset({
    "all-MiniLM-L6-v2",
    "all-mpnet-base-v2",
    "BAAI/bge-large-en-v1.5",
})
# Filename-prefix filters accepted by the /api/dataset endpoints.
ALLOWED_SOURCE_FILTERS = frozenset({"TEXT-", "IMAGES-"})
# Hard cap for uploaded document size (10 MB).
MAX_UPLOAD_BYTES = 10 * 1024 * 1024
# Project root; user-supplied paths are confined to this directory.
BASE_DIR = Path(__file__).parent.resolve()
|
|
|
|
def _validate_model_name(name: str) -> str:
    """Allow known HuggingFace models or local paths within project dir."""
    if name not in ALLOWED_MODELS:
        # Not a whitelisted model name: treat it as a path, which must
        # resolve inside the project directory (raises 400 otherwise).
        _validate_safe_path(name)
    return name
|
|
|
|
def _validate_safe_path(path_str: str) -> str:
    """Reject paths that escape the project directory."""
    candidate = Path(path_str).resolve()
    if candidate.is_relative_to(BASE_DIR):
        return path_str
    raise HTTPException(400, "Path must be within the project directory.")
|
|
|
|
| def _to_native(obj): |
| """Recursively convert numpy types to native Python types for JSON serialization.""" |
| import numpy as np |
| if isinstance(obj, dict): |
| return {k: _to_native(v) for k, v in obj.items()} |
| if isinstance(obj, (list, tuple)): |
| return [_to_native(v) for v in obj] |
| if isinstance(obj, (np.integer,)): |
| return int(obj) |
| if isinstance(obj, (np.floating,)): |
| return float(obj) |
| if isinstance(obj, np.ndarray): |
| return obj.tolist() |
| return obj |
|
|
|
|
| |
| |
| |
|
|
# FastAPI application; all routes below attach to this instance.
app = FastAPI(
    title="Contextual Similarity API",
    version="1.0.0",
)


# CORS: local dev frontends plus HuggingFace Spaces.
# NOTE(review): entries in allow_origins are matched literally, so the
# "https://*.hf.space" wildcard there is ineffective; hf.space subdomains
# are actually admitted by allow_origin_regex below — confirm and clean up.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:5173", "http://localhost:3000",
                   "https://huggingface.co", "https://*.hf.space"],
    allow_origin_regex=r"https://.*\.hf\.space",
    allow_credentials=True,
    allow_methods=["GET", "POST"],
    allow_headers=["Content-Type", "Authorization"],
)


# Process-wide singletons, populated by /api/init, /api/w2v/init,
# or the startup auto-restore below.
engine: Optional[ContextualSimilarityEngine] = None
evaluator: Optional[Evaluator] = None
w2v_engine: Optional[Word2VecEngine] = None


# On-disk state directories (overridable via env for containerized deploys).
ENGINE_SAVE_DIR = Path(os.environ.get("ENGINE_STATE_DIR", str(BASE_DIR / "engine_state")))
W2V_SAVE_DIR = Path(os.environ.get("W2V_STATE_DIR", str(BASE_DIR / "w2v_state")))
|
|
|
|
| @app.on_event("startup") |
| def _auto_restore(): |
| """Restore engine and W2V state from disk if previous saves exist.""" |
| global engine, evaluator, w2v_engine |
| if (ENGINE_SAVE_DIR / "meta.json").is_file(): |
| try: |
| engine = ContextualSimilarityEngine.load(str(ENGINE_SAVE_DIR)) |
| if engine.index is not None: |
| evaluator = Evaluator(engine) |
| logger.info("Auto-restored engine: %d chunks, %d docs", |
| len(engine.chunks), len(engine._doc_ids)) |
| except Exception: |
| logger.exception("Failed to auto-restore engine state β starting fresh") |
| if Word2VecEngine.has_saved_state(str(W2V_SAVE_DIR)): |
| try: |
| w2v_engine = Word2VecEngine.load(str(W2V_SAVE_DIR)) |
| logger.info("Auto-restored Word2Vec: %d sentences, %d vocab", |
| len(w2v_engine.sentences), len(w2v_engine.model.wv)) |
| except Exception: |
| logger.exception("Failed to auto-restore Word2Vec state β starting fresh") |
|
|
|
|
| @app.get("/api/logs/stream") |
| async def stream_logs(): |
| """SSE endpoint: streams server log lines in real-time.""" |
| async def event_generator(): |
| cursor = 0 |
| |
| lines, cursor = log_buffer.get_new_lines(cursor) |
| for line in lines[-20:]: |
| yield f"data: {line}\n\n" |
| cursor = max(cursor, 0) |
| while True: |
| await asyncio.sleep(0.5) |
| lines, cursor = log_buffer.get_new_lines(cursor) |
| if lines: |
| for line in lines: |
| yield f"data: {line}\n\n" |
|
|
| return StreamingResponse( |
| event_generator(), |
| media_type="text/event-stream", |
| headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}, |
| ) |
|
|
|
|
| @app.get("/api/logs/poll") |
| def poll_logs(cursor: int = Query(default=0, ge=0)): |
| """Polling fallback for log streaming (works through HF Spaces proxy).""" |
| lines, new_cursor = log_buffer.get_new_lines(cursor) |
| return {"lines": lines, "cursor": new_cursor} |
|
|
|
|
| |
| |
| |
|
|
class TrainRequest(BaseModel):
    """Common payload for the /api/train/* endpoints."""
    corpus_texts: list[str] = Field(max_length=10_000)  # raw training texts
    base_model: str = "all-MiniLM-L6-v2"                # HF model name or local path
    output_path: str = "./trained_model"                # validated to stay inside project dir
    epochs: int = Field(default=5, ge=1, le=50)
    batch_size: int = Field(default=16, ge=1, le=256)




class TrainKeywordsRequest(TrainRequest):
    """Keyword-supervised training: adds a keyword -> meaning description map."""
    keyword_meanings: dict[str, str]




class TrainEvalRequest(BaseModel):
    """Payload for /api/train/evaluate (base vs trained model comparison)."""
    # Each pair dict carries "text_a", "text_b" and "expected" (or "score").
    test_pairs: list[dict] = Field(max_length=1_000)
    trained_model_path: str = "./trained_model"
    base_model: str = "all-MiniLM-L6-v2"
    corpus_texts: list[str] = Field(default=[], max_length=10_000)
|
|
|
|
class InitRequest(BaseModel):
    """Payload for /api/init: engine/model configuration."""
    model_name: str = "all-MiniLM-L6-v2"
    chunk_size: int = Field(default=512, ge=50, le=8192)
    chunk_overlap: int = Field(default=128, ge=0, le=4096)  # must be < chunk_size (checked in handler)
    batch_size: int = Field(default=64, ge=1, le=512)




class DocumentRequest(BaseModel):
    """Payload for /api/documents: a single document to chunk and embed."""
    doc_id: str = Field(max_length=200)
    text: str = Field(max_length=10_000_000)




class QueryRequest(BaseModel):
    """Payload for /api/query: semantic search parameters."""
    text: str = Field(max_length=10_000)
    top_k: int = Field(default=10, ge=1, le=100)




class CompareRequest(BaseModel):
    """Payload for /api/compare: two texts to score for similarity."""
    text_a: str = Field(max_length=50_000)
    text_b: str = Field(max_length=50_000)
|
|
|
|
class KeywordAnalysisRequest(BaseModel):
    """Payload for /api/analyze/keyword."""
    keyword: str = Field(max_length=200)
    top_k: int = Field(default=10, ge=1, le=100)
    cluster_threshold: float = Field(default=0.35, ge=0.01, le=1.0)




class BatchAnalysisRequest(BaseModel):
    """Payload for /api/analyze/batch (multiple keywords at once)."""
    keywords: list[str] = Field(max_length=50)
    top_k: int = Field(default=10, ge=1, le=100)
    cluster_threshold: float = Field(default=0.35, ge=0.01, le=1.0)
    compare_across: bool = True  # also compute cross-keyword similarities




class KeywordMatchRequest(BaseModel):
    """Payload for /api/match: score a keyword against candidate meanings."""
    keyword: str = Field(max_length=200)
    candidate_meanings: list[str] = Field(max_length=50)




class EvalDisambiguationRequest(BaseModel):
    """Payload for /api/eval/disambiguation."""
    # Each entry dict carries "keyword", "text" and "true_meaning".
    ground_truth: list[dict] = Field(max_length=10_000)
    candidate_meanings: dict[str, list[str]]




class EvalRetrievalRequest(BaseModel):
    """Payload for /api/eval/retrieval: queries plus the k cutoffs to report."""
    queries: list[dict] = Field(max_length=1_000)
    k_values: list[int] = Field(default=[1, 3, 5, 10], max_length=10)
|
|
|
|
class W2VInitRequest(BaseModel):
    """Payload for /api/w2v/init: Word2Vec training hyperparameters."""
    corpus_texts: list[str] = Field(max_length=10_000)
    vector_size: int = Field(default=100, ge=50, le=500)
    window: int = Field(default=5, ge=1, le=20)
    epochs: int = Field(default=50, ge=1, le=200)




class W2VCompareRequest(BaseModel):
    """Payload for /api/w2v/compare: two texts to score with the W2V baseline."""
    text_a: str = Field(max_length=50_000)
    text_b: str = Field(max_length=50_000)




class W2VQueryRequest(BaseModel):
    """Payload for /api/w2v/query."""
    text: str = Field(max_length=10_000)
    top_k: int = Field(default=10, ge=1, le=100)




class W2VWordRequest(BaseModel):
    """Payload for the similar-words endpoints (W2V and transformer)."""
    word: str = Field(max_length=200)
    top_k: int = Field(default=10, ge=1, le=100)




class ContextAnalysisRequest(BaseModel):
    """Payload for /api/analyze/context."""
    keyword: str = Field(max_length=200)
    cluster_threshold: float = Field(default=0.35, ge=0.01, le=1.0)
    top_words: int = Field(default=8, ge=1, le=30)




class DatasetLoadRequest(BaseModel):
    """Payload for /api/dataset/load."""
    source: Literal["raw", "embeddings"] = "raw"  # raw text vs precomputed ChromaDB embeddings
    max_docs: int = Field(default=500, ge=1, le=100_000)
    min_text_length: int = Field(default=100, ge=0, le=100_000)
    source_filter: Optional[str] = None  # must be in ALLOWED_SOURCE_FILTERS when set
    build_index: bool = True
|
|
|
|
| |
| |
| |
|
|
def _run_training(req: TrainRequest, strategy: str, train_fn):
    """Common wrapper for all training endpoints: validate, log, time, train.

    `train_fn` receives a freshly built CorpusTrainer and returns the
    endpoint's result payload.
    """
    _validate_model_name(req.base_model)
    _validate_safe_path(req.output_path)
    logger.info("Training (%s): model=%s, corpus=%d texts, epochs=%d, batch=%d",
                strategy, req.base_model, len(req.corpus_texts), req.epochs, req.batch_size)
    started = time.time()
    outcome = train_fn(CorpusTrainer(req.corpus_texts, req.base_model))
    logger.info("Training (%s) complete in %.1fs β %s", strategy, time.time() - started, req.output_path)
    return outcome
|
|
| @app.post("/api/train/unsupervised") |
| def train_unsupervised(req: TrainRequest): |
| """Soft-label domain adaptation. No labels needed.""" |
| return _run_training(req, "unsupervised", |
| lambda t: t.train_unsupervised(req.output_path, req.epochs, req.batch_size)) |
|
|
|
|
| @app.post("/api/train/contrastive") |
| def train_contrastive(req: TrainRequest): |
| """Contrastive: learns from corpus structure (adjacent sentences = similar).""" |
| return _run_training(req, "contrastive", |
| lambda t: t.train_contrastive(req.output_path, req.epochs, req.batch_size)) |
|
|
|
|
| @app.post("/api/train/keywords") |
| def train_keywords(req: TrainKeywordsRequest): |
| """Keyword-supervised: provide keywordβmeaning map, pairs auto-generated.""" |
| return _run_training(req, "keyword-supervised", |
| lambda t: t.train_with_keywords(req.keyword_meanings, req.output_path, req.epochs, req.batch_size)) |
|
|
|
|
| @app.post("/api/train/evaluate") |
| def train_evaluate(req: TrainEvalRequest): |
| """Compare base model vs trained model on test pairs.""" |
| _validate_model_name(req.base_model) |
| _validate_model_name(req.trained_model_path) |
| logger.info("Evaluating: base=%s vs trained=%s, %d test pairs", |
| req.base_model, req.trained_model_path, len(req.test_pairs)) |
| corpus = req.corpus_texts or ["placeholder text for initialization."] |
| trainer = CorpusTrainer(corpus, req.base_model) |
| test_pairs = [ |
| (p["text_a"], p["text_b"], p.get("expected", p.get("score", 0.5))) |
| for p in req.test_pairs |
| ] |
| result = trainer.evaluate_model(test_pairs, req.trained_model_path) |
| logger.info("Evaluation complete: %d pairs evaluated", len(test_pairs)) |
| return result |
|
|
|
|
| |
| |
| |
|
|
| @app.post("/api/init") |
| def init_engine(req: InitRequest): |
| """Initialize the similarity engine with a model (pretrained or trained).""" |
| _validate_model_name(req.model_name) |
| if req.chunk_overlap >= req.chunk_size: |
| raise HTTPException(400, "chunk_overlap must be less than chunk_size.") |
| global engine, evaluator |
| logger.info("Initializing engine: model=%s, chunk_size=%d, overlap=%d, batch=%d", |
| req.model_name, req.chunk_size, req.chunk_overlap, req.batch_size) |
| t0 = time.time() |
| engine = ContextualSimilarityEngine( |
| model_name=req.model_name, |
| chunk_size=req.chunk_size, |
| chunk_overlap=req.chunk_overlap, |
| batch_size=req.batch_size, |
| ) |
| evaluator = None |
| elapsed = round(time.time() - t0, 2) |
| logger.info("Engine initialized in %.2fs (model=%s)", elapsed, req.model_name) |
| return {"status": "ok", "model": req.model_name, "load_time_seconds": elapsed} |
|
|
|
|
| @app.post("/api/documents") |
| def add_document(req: DocumentRequest): |
| _ensure_engine() |
| logger.info("Adding document: id=%s, text_length=%d", req.doc_id, len(req.text)) |
| try: |
| chunks = engine.add_document(req.doc_id, req.text) |
| except ValueError as e: |
| raise HTTPException(status_code=400, detail=str(e)) |
| logger.info("Document '%s' added: %d chunks", req.doc_id, len(chunks)) |
| return { |
| "status": "ok", "doc_id": req.doc_id, "num_chunks": len(chunks), |
| "chunks_preview": [{"index": c.chunk_index, "text": c.text[:150]} for c in chunks[:5]], |
| } |
|
|
|
|
| @app.post("/api/documents/upload") |
| async def upload_document(file: UploadFile = File(...), doc_id: Optional[str] = Form(None)): |
| _ensure_engine() |
| contents = await file.read() |
| if len(contents) > MAX_UPLOAD_BYTES: |
| raise HTTPException(413, f"File too large. Maximum size is {MAX_UPLOAD_BYTES // (1024 * 1024)}MB.") |
| try: |
| text = contents.decode("utf-8") |
| except UnicodeDecodeError: |
| raise HTTPException(400, "File must be valid UTF-8 text.") |
| d_id = doc_id or Path(file.filename or "upload").stem |
| try: |
| chunks = engine.add_document(d_id, text) |
| except ValueError as e: |
| raise HTTPException(status_code=400, detail=str(e)) |
| return {"status": "ok", "doc_id": d_id, "num_chunks": len(chunks)} |
|
|
|
|
| @app.post("/api/index/build") |
| def build_index(): |
| _ensure_engine() |
| logger.info("Building FAISS index (%d documents in corpus)...", len(engine.corpus)) |
| t0 = time.time() |
| try: |
| engine.build_index(show_progress=True) |
| except RuntimeError as e: |
| raise HTTPException(status_code=400, detail=str(e)) |
| global evaluator |
| evaluator = Evaluator(engine) |
| elapsed = round(time.time() - t0, 2) |
| logger.info("FAISS index built: %d vectors (dim=%d) in %.2fs", |
| engine.index.ntotal, engine.embedding_dim, elapsed) |
| |
| try: |
| engine.save(str(ENGINE_SAVE_DIR)) |
| except Exception: |
| logger.warning("Auto-save after index build failed", exc_info=True) |
| return { |
| "status": "ok", "total_chunks": engine.index.ntotal, |
| "embedding_dim": engine.embedding_dim, "build_time_seconds": elapsed, |
| } |
|
|
|
|
| @app.post("/api/query") |
| def query_similar(req: QueryRequest): |
| _ensure_engine(); _ensure_index() |
| logger.info("Query: text='%s...' top_k=%d", req.text[:80], req.top_k) |
| results = engine.query(req.text, top_k=req.top_k) |
| return {"query": req.text, "results": [ |
| {"rank": r.rank, "score": round(r.score, 4), "doc_id": r.chunk.doc_id, |
| "chunk_index": r.chunk.chunk_index, "text": r.chunk.text} |
| for r in results |
| ]} |
|
|
|
|
| @app.post("/api/compare") |
| def compare_texts(req: CompareRequest): |
| _ensure_engine() |
| logger.info("Compare: text_a='%s...' vs text_b='%s...'", req.text_a[:60], req.text_b[:60]) |
| return {"text_a": req.text_a, "text_b": req.text_b, "similarity": round(engine.compare_texts(req.text_a, req.text_b), 4)} |
|
|
|
|
| @app.post("/api/analyze/keyword") |
| def analyze_keyword(req: KeywordAnalysisRequest): |
| _ensure_engine(); _ensure_index() |
| logger.info("Keyword analysis: keyword='%s', top_k=%d, threshold=%.2f", |
| req.keyword, req.top_k, req.cluster_threshold) |
| return _serialize_analysis(engine.analyze_keyword(req.keyword, req.top_k, req.cluster_threshold)) |
|
|
|
|
| @app.post("/api/analyze/similar-words") |
| def analyze_similar_words(req: W2VWordRequest): |
| """Find words that appear in similar contexts using transformer embeddings.""" |
| _ensure_engine(); _ensure_index() |
| logger.info("Similar words (transformer): word='%s', top_k=%d", req.word, req.top_k) |
| results = engine.similar_words(req.word, req.top_k) |
| return {"word": req.word, "similar": results} |
|
|
|
|
| @app.post("/api/analyze/context") |
| def analyze_context(req: ContextAnalysisRequest): |
| """Infer what a keyword likely means from its surrounding context words.""" |
| _ensure_engine(); _ensure_index() |
| logger.info("Context analysis: keyword='%s', threshold=%.2f, top_words=%d", |
| req.keyword, req.cluster_threshold, req.top_words) |
| return engine.infer_keyword_meanings( |
| req.keyword, |
| top_words=req.top_words, |
| cluster_threshold=req.cluster_threshold, |
| ) |
|
|
|
|
| @app.post("/api/analyze/batch") |
| def batch_analyze(req: BatchAnalysisRequest): |
| _ensure_engine(); _ensure_index() |
| logger.info("Batch analysis: %d keywords=%s, top_k=%d", |
| len(req.keywords), req.keywords[:5], req.top_k) |
| results = engine.batch_analyze_keywords(req.keywords, req.top_k, req.cluster_threshold, req.compare_across) |
| return {kw: _serialize_analysis(a) for kw, a in results.items()} |
|
|
|
|
| @app.post("/api/match") |
| def match_keyword(req: KeywordMatchRequest): |
| _ensure_engine(); _ensure_index() |
| results = engine.match_keyword_to_meaning(req.keyword, req.candidate_meanings) |
| return {"keyword": req.keyword, "candidate_meanings": req.candidate_meanings, "matches": [ |
| {"doc_id": r["chunk"].doc_id, "chunk_index": r["chunk"].chunk_index, "text": r["chunk"].text[:300], |
| "best_match": r["best_match"], "best_score": round(r["best_score"], 4), |
| "all_scores": {k: round(v, 4) for k, v in r["all_scores"].items()}} |
| for r in results |
| ]} |
|
|
|
|
| |
| |
| |
|
|
| @app.post("/api/eval/disambiguation") |
| def evaluate_disambiguation(req: EvalDisambiguationRequest): |
| _ensure_evaluator() |
| gt = [GroundTruthEntry(keyword=e["keyword"], text=e["text"], true_meaning=e["true_meaning"]) for e in req.ground_truth] |
| metrics = evaluator.evaluate_disambiguation(gt, req.candidate_meanings) |
| return _to_native({"metrics": [ |
| {"keyword": m.keyword, "accuracy": m.accuracy, "weighted_f1": m.weighted_f1, |
| "per_meaning_precision": m.per_meaning_precision, "per_meaning_recall": m.per_meaning_recall, |
| "per_meaning_f1": m.per_meaning_f1, "confusion_matrix": m.confusion, "total_samples": m.total_samples} |
| for m in metrics |
| ]}) |
|
|
|
|
| @app.post("/api/eval/retrieval") |
| def evaluate_retrieval(req: EvalRetrievalRequest): |
| _ensure_evaluator() |
| metrics = evaluator.evaluate_retrieval(req.queries, req.k_values) |
| return _to_native({"metrics": [ |
| {"query": m.query, "mrr": round(float(m.mrr), 4), |
| "precision_at_k": {str(k): round(float(v), 4) for k, v in m.precision_at_k.items()}, |
| "recall_at_k": {str(k): round(float(v), 4) for k, v in m.recall_at_k.items()}, |
| "ndcg_at_k": {str(k): round(float(v), 4) for k, v in m.ndcg_at_k.items()}, |
| "avg_similarity": round(float(m.avg_similarity), 4), "top_score": round(float(m.top_score), 4)} |
| for m in metrics |
| ]}) |
|
|
|
|
| @app.get("/api/eval/similarity-distribution") |
| def similarity_distribution(): |
| _ensure_evaluator() |
| return _to_native(evaluator.analyze_similarity_distribution()) |
|
|
|
|
| @app.get("/api/eval/report") |
| def get_eval_report(): |
| _ensure_evaluator() |
| return _to_native(evaluator.get_report().summary()) |
|
|
|
|
| @app.get("/api/stats") |
| def get_stats(): |
| _ensure_engine() |
| return engine.get_stats() |
|
|
|
|
| @app.get("/api/corpus/texts") |
| def get_corpus_texts(max_docs: int = Query(default=500, ge=1, le=10_000)): |
| """Return loaded document texts grouped by doc_id (for use as training corpus).""" |
| _ensure_engine() |
| |
| docs: dict[str, list[str]] = {} |
| for chunk in engine.chunks: |
| if chunk.doc_id not in docs: |
| docs[chunk.doc_id] = [] |
| docs[chunk.doc_id].append(chunk.text) |
| |
| result = [] |
| for doc_id in sorted(docs.keys()): |
| if len(result) >= max_docs: |
| break |
| result.append({"doc_id": doc_id, "text": "\n".join(docs[doc_id])}) |
| return {"documents": result, "count": len(result)} |
|
|
|
|
| @app.get("/api/documents/{doc_id}") |
| def get_document(doc_id: str): |
| """Return the full text of a document by reconstructing its chunks.""" |
| _ensure_engine() |
| chunks = [c for c in engine.chunks if c.doc_id == doc_id] |
| if not chunks: |
| raise HTTPException(404, f"Document '{doc_id}' not found.") |
| chunks.sort(key=lambda c: c.chunk_index) |
| full_text = "\n".join(c.text for c in chunks) |
| return {"doc_id": doc_id, "text": full_text, "num_chunks": len(chunks)} |
|
|
|
|
| @app.post("/api/engine/save") |
| def save_engine(): |
| """Save current engine state to disk for later restore.""" |
| _ensure_engine() |
| result = engine.save(str(ENGINE_SAVE_DIR)) |
| return {"status": "ok", **result} |
|
|
|
|
| @app.post("/api/engine/load") |
| def load_engine_state(): |
| """Load a previously saved engine state from disk.""" |
| global engine, evaluator |
| if not (ENGINE_SAVE_DIR / "meta.json").is_file(): |
| raise HTTPException(400, "No saved engine state found.") |
| engine = ContextualSimilarityEngine.load(str(ENGINE_SAVE_DIR)) |
| evaluator = Evaluator(engine) if engine.index is not None else None |
| return {"status": "ok", **engine.get_stats()} |
|
|
|
|
| @app.get("/api/engine/has-saved-state") |
| def has_saved_state(): |
| """Check if a saved engine state exists on disk.""" |
| exists = (ENGINE_SAVE_DIR / "meta.json").is_file() |
| return {"exists": exists} |
|
|
|
|
| |
| |
| |
|
|
| @app.post("/api/w2v/init") |
| def w2v_init(req: W2VInitRequest): |
| """Train Word2Vec on corpus for comparison.""" |
| global w2v_engine |
| logger.info("Word2Vec init: %d texts, vector_size=%d, window=%d, epochs=%d", |
| len(req.corpus_texts), req.vector_size, req.window, req.epochs) |
| t0 = time.time() |
| w2v_engine = Word2VecEngine(vector_size=req.vector_size, window=req.window, epochs=req.epochs) |
| for i, text in enumerate(req.corpus_texts): |
| w2v_engine.add_document(f"doc_{i}", text) |
| stats = w2v_engine.build_index() |
| elapsed = round(time.time() - t0, 2) |
| logger.info("Word2Vec ready: %s in %.2fs", stats, elapsed) |
| |
| try: |
| w2v_engine.save(str(W2V_SAVE_DIR)) |
| except Exception: |
| logger.warning("Auto-save W2V after init failed", exc_info=True) |
| return {**stats, "seconds": elapsed} |
|
|
|
|
| @app.post("/api/w2v/init-from-engine") |
| def w2v_init_from_engine( |
| vector_size: int = Query(default=100, ge=50, le=500), |
| window: int = Query(default=5, ge=1, le=20), |
| epochs: int = Query(default=50, ge=1, le=200), |
| ): |
| """Train Word2Vec directly from all documents already loaded in the engine. |
| |
| This avoids the round-trip through the frontend and uses ALL engine docs. |
| """ |
| global w2v_engine |
| _ensure_engine() |
| if not engine.chunks: |
| raise HTTPException(400, "No documents in the engine. Load a dataset first.") |
|
|
| |
| docs: dict[str, list[str]] = {} |
| for chunk in engine.chunks: |
| if chunk.doc_id not in docs: |
| docs[chunk.doc_id] = [] |
| docs[chunk.doc_id].append(chunk.text) |
|
|
| logger.info("Word2Vec init from engine: %d documents, vector_size=%d, window=%d, epochs=%d", |
| len(docs), vector_size, window, epochs) |
| t0 = time.time() |
| w2v_engine = Word2VecEngine(vector_size=vector_size, window=window, epochs=epochs) |
| for doc_id, chunks_list in docs.items(): |
| w2v_engine.add_document(doc_id, "\n".join(chunks_list)) |
| stats = w2v_engine.build_index() |
| elapsed = round(time.time() - t0, 2) |
| logger.info("Word2Vec ready: %s in %.2fs", stats, elapsed) |
| |
| try: |
| w2v_engine.save(str(W2V_SAVE_DIR)) |
| except Exception: |
| logger.warning("Auto-save W2V after init failed", exc_info=True) |
| return {**stats, "seconds": elapsed, "documents_used": len(docs)} |
|
|
|
|
| @app.post("/api/w2v/compare") |
| def w2v_compare(req: W2VCompareRequest): |
| _ensure_w2v() |
| return {"text_a": req.text_a, "text_b": req.text_b, |
| "similarity": round(w2v_engine.compare_texts(req.text_a, req.text_b), 4)} |
|
|
|
|
| @app.post("/api/w2v/query") |
| def w2v_query(req: W2VQueryRequest): |
| _ensure_w2v() |
| results = w2v_engine.query(req.text, top_k=req.top_k) |
| return {"query": req.text, "results": [ |
| {"rank": r.rank, "score": round(r.score, 4), "doc_id": r.doc_id, "text": r.text} |
| for r in results |
| ]} |
|
|
|
|
| @app.post("/api/w2v/similar-words") |
| def w2v_similar_words(req: W2VWordRequest): |
| _ensure_w2v() |
| similar = w2v_engine.most_similar_words(req.word, req.top_k) |
| return {"word": req.word, "similar": [{"word": w, "score": round(s, 4)} for w, s in similar]} |
|
|
|
|
| @app.get("/api/w2v/status") |
| def w2v_status(): |
| """Check if Word2Vec is loaded (from training or restored from disk).""" |
| if w2v_engine is not None and w2v_engine.model is not None: |
| return { |
| "ready": True, |
| "vocab_size": len(w2v_engine.model.wv), |
| "sentences": len(w2v_engine.sentences), |
| "vector_size": w2v_engine.vector_size, |
| } |
| has_saved = Word2VecEngine.has_saved_state(str(W2V_SAVE_DIR)) |
| return {"ready": False, "has_saved_state": has_saved} |
|
|
|
|
| @app.post("/api/w2v/reset") |
| def w2v_reset(): |
| """Delete saved Word2Vec state and clear the in-memory model.""" |
| global w2v_engine |
| w2v_engine = None |
| import shutil |
| if W2V_SAVE_DIR.is_dir(): |
| shutil.rmtree(str(W2V_SAVE_DIR)) |
| logger.info("Word2Vec state deleted from %s", W2V_SAVE_DIR) |
| return {"status": "ok", "message": "Word2Vec state cleared. You can retrain now."} |
|
|
|
|
| |
| |
| |
|
|
| @app.get("/api/dataset/info") |
| def dataset_info(): |
| """Get metadata about available HuggingFace datasets.""" |
| return get_dataset_info() |
|
|
|
|
| @app.post("/api/dataset/load") |
| def dataset_load(req: DatasetLoadRequest): |
| """Load Epstein Files dataset from HuggingFace into the engine.""" |
| global engine, evaluator |
| if engine is None: |
| logger.info("Engine not initialized β auto-initializing with default model...") |
| engine = ContextualSimilarityEngine( |
| model_name="all-MiniLM-L6-v2", |
| chunk_size=512, |
| chunk_overlap=128, |
| batch_size=64, |
| ) |
| evaluator = None |
| logger.info("Engine auto-initialized with all-MiniLM-L6-v2") |
| if req.source_filter and req.source_filter not in ALLOWED_SOURCE_FILTERS: |
| raise HTTPException(400, f"source_filter must be one of: {sorted(ALLOWED_SOURCE_FILTERS)}") |
| logger.info("Dataset load: source=%s, max_docs=%d, min_text=%d, filter=%s, build_index=%s", |
| req.source, req.max_docs, req.min_text_length, req.source_filter, req.build_index) |
| t0 = time.time() |
| try: |
| if req.source == "embeddings": |
| result = import_chromadb_to_engine(engine, max_chunks=req.max_docs * 10) |
| else: |
| result = load_raw_to_engine( |
| engine, |
| max_docs=req.max_docs, |
| min_text_length=req.min_text_length, |
| source_filter=req.source_filter, |
| build_index=req.build_index, |
| ) |
| logger.info("Dataset loaded in %.1fs", time.time() - t0) |
| |
| try: |
| engine.save(str(ENGINE_SAVE_DIR)) |
| except Exception: |
| logger.warning("Auto-save after dataset load failed", exc_info=True) |
| return result |
| except Exception: |
| logger.exception("Dataset load failed") |
| raise HTTPException(500, "Dataset load failed. Check server logs for details.") |
|
|
|
|
| @app.post("/api/dataset/preview") |
| def dataset_preview( |
| max_docs: int = Query(default=10, ge=1, le=100), |
| min_text_length: int = Query(default=100, ge=0, le=100_000), |
| source_filter: Optional[str] = Query(default=None), |
| ): |
| """Preview a few documents from the raw dataset without loading into engine.""" |
| if source_filter and source_filter not in ALLOWED_SOURCE_FILTERS: |
| raise HTTPException(400, f"source_filter must be one of: {sorted(ALLOWED_SOURCE_FILTERS)}") |
| try: |
| docs = load_raw_dataset( |
| max_docs=max_docs, |
| min_text_length=min_text_length, |
| source_filter=source_filter, |
| ) |
| return { |
| "count": len(docs), |
| "documents": [ |
| {"doc_id": d["doc_id"], "filename": d["filename"], |
| "text_preview": d["text"][:500], "text_length": len(d["text"])} |
| for d in docs |
| ], |
| } |
| except Exception: |
| logger.exception("Dataset preview failed") |
| raise HTTPException(500, "Dataset preview failed. Check server logs for details.") |
|
|
|
|
| |
| |
| |
|
|
def _ensure_engine():
    """Abort with 400 unless the engine singleton has been created."""
    if engine is None:
        raise HTTPException(400, "Engine not initialized. POST /api/init first.")
|
|
def _ensure_index():
    """Abort with 400 unless a FAISS index is available to query.

    Also guards against the engine itself being absent, so a caller that
    forgets _ensure_engine() gets a clean 400 instead of an
    AttributeError-driven 500 (the original dereferenced engine.index
    unconditionally).
    """
    _ensure_engine()
    if engine.index is None:
        raise HTTPException(400, "Index not built. POST /api/index/build first.")
|
|
def _ensure_evaluator():
    """Lazily create the Evaluator; requires engine and index to exist."""
    global evaluator
    if evaluator is None:
        _ensure_engine(); _ensure_index()
        evaluator = Evaluator(engine)


def _ensure_w2v():
    """Abort with 400 unless the Word2Vec baseline has been trained/restored."""
    if w2v_engine is None:
        raise HTTPException(400, "Word2Vec not initialized. POST /api/w2v/init first.")
|
|
| def _serialize_analysis(analysis): |
| return { |
| "keyword": analysis.keyword, |
| "total_occurrences": analysis.total_occurrences, |
| "meaning_clusters": [{ |
| "cluster_id": c["cluster_id"], "size": c["size"], |
| "representative_text": c["representative_text"], |
| "contexts": [{"doc_id": ctx.chunk.doc_id, "chunk_index": ctx.chunk.chunk_index, |
| "text": ctx.chunk.text[:300], "highlight_positions": ctx.highlight_positions} |
| for ctx in c["contexts"]], |
| "similar_passages": [{"rank": s.rank, "score": round(s.score, 4), |
| "doc_id": s.chunk.doc_id, "text": s.chunk.text[:200]} |
| for s in c["similar_passages"]], |
| } for c in analysis.meaning_clusters], |
| "cross_keyword_similarities": {k: round(v, 4) for k, v in analysis.cross_keyword_similarities.items()}, |
| } |
|
|
| |
| |
| |
|
|
# Built React frontend; present only when the SPA has been compiled/deployed.
_FRONTEND_DIR = BASE_DIR / "frontend" / "dist"


@app.get("/api/debug/frontend")
async def debug_frontend_files():
    """List all files in the frontend dist directory (for debugging deploys).

    Uses pathlib.rglob instead of the original os.walk/os.path combination,
    and sorts the paths so output is deterministic across filesystems.
    """
    if not _FRONTEND_DIR.is_dir():
        return {"exists": False, "path": str(_FRONTEND_DIR)}
    files = [
        {"path": str(p.relative_to(_FRONTEND_DIR)), "size": p.stat().st_size}
        for p in sorted(_FRONTEND_DIR.rglob("*"))
        if p.is_file()
    ]
    return {"exists": True, "path": str(_FRONTEND_DIR), "files": files}
|
|
|
|
# Register the SPA catch-all route only when a built frontend exists, so the
# route table stays clean in dev mode (frontend served separately on :5173).
if _FRONTEND_DIR.is_dir():
    @app.get("/{full_path:path}")
    async def serve_frontend(full_path: str):
        """Serve the React SPA β static files or index.html fallback."""
        if full_path:
            file_path = (_FRONTEND_DIR / full_path).resolve()
            # Path-traversal guard: only serve files that resolve inside dist/.
            if file_path.is_file() and file_path.is_relative_to(_FRONTEND_DIR):
                return FileResponse(file_path)
        # SPA fallback: unknown routes get index.html for client-side routing.
        return FileResponse(_FRONTEND_DIR / "index.html")


    logger.info("Frontend serving enabled from %s", _FRONTEND_DIR)
|
|
|
|
| if __name__ == "__main__": |
| import uvicorn |
| host = os.environ.get("HOST", "127.0.0.1") |
| port = int(os.environ.get("PORT", "8000")) |
| has_frontend = _FRONTEND_DIR.is_dir() |
| logger.info("=" * 60) |
| logger.info("Contextual Similarity API starting") |
| logger.info(" Server: http://%s:%d", host, port) |
| if has_frontend: |
| logger.info(" Frontend: http://%s:%d (built-in)", host, port) |
| else: |
| logger.info(" Frontend: http://localhost:5173 (dev server)") |
| logger.info(" API Docs: http://%s:%d/docs", host, port) |
| logger.info("=" * 60) |
| uvicorn.run(app, host=host, port=port) |
|
|