"""FastAPI endpoints for the RAG service.

Defines the request/response schemas and an NDJSON serialization helper.
Caching and per-stage time logging are used throughout the module so every
pipeline stage can be profiled.
"""

import json
import os
import re
import time
from typing import Any

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from huggingface_hub import InferenceClient
from pydantic import BaseModel, Field

from config_loader import cfg
from vector_db import get_index_by_name, load_chunks_with_local_cache
from retriever.retriever import HybridRetriever
from retriever.generator import RAGGenerator
from retriever.processor import ChunkProcessor
from models.llama_3_8b import Llama3_8B
from models.mistral_7b import Mistral_7b
from models.qwen_2_5 import Qwen2_5
from models.deepseek_v3 import DeepSeek_V3
from models.tiny_aya import TinyAya


class PredictRequest(BaseModel):
    """Request payload accepted by /predict and /predict/stream."""

    query: str = Field(..., min_length=1, description="User query text")
    model: str = Field(default="Llama-3-8B", description="Model name key")
    top_k: int = Field(default=10, ge=1, le=50)
    final_k: int = Field(default=5, ge=1, le=20)
    mode: str = Field(default="hybrid", description="semantic | bm25 | hybrid")
    rerank_strategy: str = Field(default="cross-encoder", description="cross-encoder | rrf | none")


class PredictResponse(BaseModel):
    """Response payload returned by /predict."""

    model: str
    answer: str
    contexts: list[str]
    retrieved_chunks: list[dict[str, Any]]


class TitleRequest(BaseModel):
    """Request payload accepted by /predict/title."""

    query: str = Field(..., min_length=1, description="First user message")


class TitleResponse(BaseModel):
    """Response payload returned by /predict/title."""

    title: str
    source: str


def _to_ndjson(payload: dict[str, Any]) -> str:
    """Serialize *payload* as a single newline-terminated NDJSON record."""
    record = json.dumps(payload, ensure_ascii=False)
    return record + "\n"


# Simplest possible implementation to determine a chat title.
# Used as a fallback in case HF generation fails.
def _title_from_query(query: str) -> str:
    """Rule-based fallback title: first few non-stop-word tokens of *query*.

    Returns "New Chat" when the query yields no usable words; the result is
    capped at 80 characters.
    """
    stop_words = {
        "a", "an", "and", "are", "as", "at", "be", "by", "can", "do",
        "for", "from", "how", "i", "in", "is", "it", "me", "my", "of",
        "on", "or", "please", "show", "tell", "that", "the", "this",
        "to", "we", "what", "when", "where", "which", "why", "with",
        "you", "your",
    }
    words = re.findall(r"[A-Za-z0-9][A-Za-z0-9\-_/+]*", query)
    if not words:
        return "New Chat"
    filtered: list[str] = []
    for word in words:
        cleaned = word.strip("-_/+")
        if not cleaned:
            continue
        if cleaned.lower() in stop_words:
            continue
        filtered.append(cleaned)
        if len(filtered) >= 6:  # cap the title at six meaningful words
            break
    chosen = filtered if filtered else words[:6]
    normalized = [w.capitalize() if w.islower() else w for w in chosen]
    title = " ".join(normalized).strip()
    return title[:80] if title else "New Chat"


# Actual title generation uses an HF model with a simple prompt, plus the
# formatting rules below to ensure clean output. If generation fails or
# returns an empty title, the caller falls back to the rule-based method.
# Called from the /predict/title endpoint.
def _clean_title_text(raw: str) -> str:
    """Normalize raw model output into a clean title (<= 8 words, 80 chars)."""
    text = (raw or "").strip()
    text = text.replace("\n", " ").replace("\r", " ")
    # Strip wrapping quotes/backticks the model often adds.
    text = re.sub(r"^[\"'`\s]+|[\"'`\s]+$", "", text)
    text = re.sub(r"\s+", " ", text).strip()
    words = text.split()
    if len(words) > 8:
        text = " ".join(words[:8])
    return text[:80]


def _title_from_hf(query: str, client: InferenceClient, model_id: str) -> str | None:
    """Ask one HF chat model for a short title; return None on unusable output."""
    system_prompt = (
        "You generate short chat titles. Return only a title, no punctuation at the end, no quotes."
    )
    user_prompt = (
        "Create a concise 3-7 word title for this user request:\n"
        f"{query}"
    )
    response = client.chat_completion(
        model=model_id,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        max_tokens=24,
        temperature=0.3,
    )
    if not response or not response.choices:
        return None
    raw_title = response.choices[0].message.content or ""
    title = _clean_title_text(raw_title)
    # "New Chat" from the model carries no information; treat as a failure.
    if not title or title.lower() == "new chat":
        return None
    return title


def _parse_title_model_candidates() -> list[str]:
    """Read the ordered list of title-model IDs from TITLE_MODEL_IDS env var."""
    raw = os.getenv(
        "TITLE_MODEL_IDS",
        "Qwen/Qwen2.5-1.5B-Instruct,CohereLabs/tiny-aya-global,meta-llama/Meta-Llama-3-8B-Instruct",
    )
    models = [m.strip() for m in raw.split(",") if m.strip()]
    return models or ["meta-llama/Meta-Llama-3-8B-Instruct"]


def _build_retrieved_chunks(
    contexts: list[str],
    chunk_lookup: dict[str, dict[str, Any]],
) -> list[dict[str, Any]]:
    """Attach source metadata (title/url/chunk_index) to each context, ranked 1..n."""
    if not contexts:
        return []
    retrieved_chunks: list[dict[str, Any]] = []
    for idx, text in enumerate(contexts, start=1):
        meta = chunk_lookup.get(text, {})
        title = meta.get("title") or "Untitled"
        url = meta.get("url") or ""
        chunk_index = meta.get("chunk_index")
        retrieved_chunks.append(
            {
                "rank": idx,
                "text": text,
                "source_title": title,
                "source_url": url,
                "chunk_index": chunk_index,
            }
        )
    return retrieved_chunks


# FastAPI setup.
# FastAPI lets us define Python-based endpoints that are called from the
# React-based frontend.
app = FastAPI(title="RAG-AS3 API", version="0.1.0")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Shared mutable service state, populated once at startup.
state: dict[str, Any] = {}


def _build_models(hf_token: str) -> dict[str, Any]:
    """Instantiate every supported LLM wrapper, keyed by its public name."""
    return {
        "Llama-3-8B": Llama3_8B(token=hf_token),
        "Mistral-7B": Mistral_7b(token=hf_token),
        "Qwen-2.5": Qwen2_5(token=hf_token),
        "DeepSeek-V3": DeepSeek_V3(token=hf_token),
        "TinyAya": TinyAya(token=hf_token),
    }


def _resolve_model(name: str, models: dict[str, Any]) -> tuple[str, Any]:
    """Resolve a user-supplied model name (or lowercase alias) to (key, instance).

    Raises:
        HTTPException: 400 when the name matches no known model.
    """
    aliases = {
        "llama": "Llama-3-8B",
        "mistral": "Mistral-7B",
        "qwen": "Qwen-2.5",
        "deepseek": "DeepSeek-V3",
        "tinyaya": "TinyAya",
    }
    model_key = aliases.get(name.lower(), name)
    if model_key not in models:
        allowed = ", ".join(models.keys())
        raise HTTPException(status_code=400, detail=f"Unknown model '{name}'. Use one of: {allowed}")
    return model_key, models[model_key]


# On startup most of the time is spent loading chunks from Pinecone.
# This is done because BM25 needs the entire corpus in memory and we want to
# avoid loading it on every query, so loading it once at startup is better.
# Could improve this — keeping the entire corpus in memory will not scale well.
# NOTE(review): @app.on_event is deprecated in current FastAPI in favor of
# lifespan handlers; kept as-is for behavioral compatibility.
@app.on_event("startup")
def startup_event() -> None:
    """Load env vars, Pinecone index, chunks, retriever, and models into `state`."""
    startup_start = time.perf_counter()

    dotenv_start = time.perf_counter()
    load_dotenv()
    dotenv_time = time.perf_counter() - dotenv_start

    env_start = time.perf_counter()
    hf_token = os.getenv("HF_TOKEN")
    pinecone_api_key = os.getenv("PINECONE_API_KEY")
    env_time = time.perf_counter() - env_start
    if not pinecone_api_key:
        raise RuntimeError("PINECONE_API_KEY not found in environment variables")
    if not hf_token:
        raise RuntimeError("HF_TOKEN not found in environment variables")

    index_name = "cbt-book-recursive"
    # Keep retrieval embedding model aligned with the one used at ingest time
    # to avoid Pinecone dimension mismatch errors (e.g., 384 vs 512).
    embed_model_name = cfg.processing.get("embedding_model", "all-MiniLM-L6-v2")
    project_root = os.path.dirname(os.path.abspath(__file__))
    cache_dir = os.getenv("BM25_CACHE_DIR", os.path.join(project_root, ".cache"))
    force_cache_refresh = os.getenv("BM25_CACHE_REFRESH", "0").lower() in {"1", "true", "yes"}

    index_start = time.perf_counter()
    index = get_index_by_name(
        api_key=pinecone_api_key, index_name=index_name
    )
    index_time = time.perf_counter() - index_start

    chunks_start = time.perf_counter()
    final_chunks, chunk_source = load_chunks_with_local_cache(
        index=index,
        index_name=index_name,
        cache_dir=cache_dir,
        batch_size=100,
        force_refresh=force_cache_refresh,
    )
    chunk_load_time = time.perf_counter() - chunks_start
    if not final_chunks:
        raise RuntimeError("No chunks found in Pinecone metadata. Run indexing once before API mode.")

    processor_start = time.perf_counter()
    proc = ChunkProcessor(model_name=embed_model_name, verbose=False, load_hf_embeddings=False)
    processor_time = time.perf_counter() - processor_start

    retriever_start = time.perf_counter()
    retriever = HybridRetriever(final_chunks, proc.encoder, verbose=False)
    retriever_time = time.perf_counter() - retriever_start

    rag_start = time.perf_counter()
    rag_engine = RAGGenerator()
    rag_time = time.perf_counter() - rag_start

    models_start = time.perf_counter()
    models = _build_models(hf_token)
    models_time = time.perf_counter() - models_start

    state_start = time.perf_counter()
    # Map chunk text -> source metadata so responses can cite title/url;
    # first occurrence wins on duplicate text.
    chunk_lookup: dict[str, dict[str, Any]] = {}
    for chunk in final_chunks:
        metadata = chunk.get("metadata", {})
        text = metadata.get("text")
        if not text or text in chunk_lookup:
            continue
        chunk_lookup[text] = {
            "title": metadata.get("title", "Untitled"),
            "url": metadata.get("url", ""),
            "chunk_index": metadata.get("chunk_index"),
        }
    state["index"] = index
    state["retriever"] = retriever
    state["rag_engine"] = rag_engine
    state["models"] = models
    state["chunk_lookup"] = chunk_lookup
    state["title_model_ids"] = _parse_title_model_candidates()
    state["title_client"] = InferenceClient(token=hf_token)
    state_time = time.perf_counter() - state_start

    startup_time = time.perf_counter() - startup_start
    print(
        f"API startup complete | chunks={len(final_chunks)} | "
        f"dotenv={dotenv_time:.3f}s | "
        f"env={env_time:.3f}s | "
        f"index={index_time:.3f}s | "
        f"cache_dir={cache_dir} | "
        f"force_cache_refresh={force_cache_refresh} | "
        f"chunk_source={chunk_source} | "
        f"chunk_load={chunk_load_time:.3f}s | "
        f"processor={processor_time:.3f}s | "
        f"retriever={retriever_time:.3f}s | "
        f"rag={rag_time:.3f}s | "
        f"models={models_time:.3f}s | "
        f"state={state_time:.3f}s | "
        f"total={startup_time:.3f}s"
    )


@app.get("/health")
def health() -> dict[str, str]:
    """Readiness probe: "ok" once all startup state is populated."""
    ready = all(k in state for k in ("index", "retriever", "rag_engine", "models"))
    return {"status": "ok" if ready else "starting"}


# Title generation endpoint.
# Called only once, when a new chat is created after the first prompt.
@app.post("/predict/title", response_model=TitleResponse)
def suggest_title(payload: TitleRequest) -> TitleResponse:
    """Generate a short chat title via HF models, falling back to rules."""
    query = payload.query.strip()
    if not query:
        raise HTTPException(status_code=400, detail="Query cannot be empty")
    fallback_title = _title_from_query(query)
    title_client: InferenceClient | None = state.get("title_client")
    title_model_ids: list[str] = state.get("title_model_ids", _parse_title_model_candidates())
    if title_client is not None:
        for title_model_id in title_model_ids:
            try:
                hf_title = _title_from_hf(query, title_client, title_model_id)
                if hf_title:
                    return TitleResponse(title=hf_title, source=f"hf:{title_model_id}")
            except Exception as exc:
                err_text = str(exc)
                # Provider/model availability differs across HF accounts; skip unsupported models.
                if "model_not_supported" in err_text or "not supported by any provider" in err_text:
                    continue
                print(f"Title generation model failed ({title_model_id}): {exc}")
                continue
    print("Title generation fallback triggered: no title model available/successful")
    return TitleResponse(title=fallback_title, source="rule-based")


def _prepare_prediction(payload: PredictRequest) -> dict[str, Any]:
    """Shared precheck/retrieval pipeline for /predict and /predict/stream.

    Validates the request, resolves the model, and runs retrieval. Returns a
    dict with everything the endpoints need (query, model_name, model_instance,
    contexts, rag_engine, chunk_lookup) plus per-stage timings.

    Raises:
        HTTPException: 503 before startup completes, 400 on an empty query,
            404 when retrieval yields no context chunks.
    """
    precheck_start = time.perf_counter()
    if not state:
        raise HTTPException(status_code=503, detail="Service not initialized yet")
    query = payload.query.strip()
    if not query:
        raise HTTPException(status_code=400, detail="Query cannot be empty")
    precheck_time = time.perf_counter() - precheck_start

    state_access_start = time.perf_counter()
    retriever: HybridRetriever = state["retriever"]
    index = state["index"]
    rag_engine: RAGGenerator = state["rag_engine"]
    models: dict[str, Any] = state["models"]
    chunk_lookup: dict[str, dict[str, Any]] = state.get("chunk_lookup", {})
    state_access_time = time.perf_counter() - state_access_start

    model_resolve_start = time.perf_counter()
    model_name, model_instance = _resolve_model(payload.model, models)
    model_resolve_time = time.perf_counter() - model_resolve_start

    retrieval_start = time.perf_counter()
    contexts = retriever.search(
        query,
        index,
        mode=payload.mode,
        rerank_strategy=payload.rerank_strategy,
        use_mmr=True,
        top_k=payload.top_k,
        final_k=payload.final_k,
        verbose=False,
    )
    retrieval_time = time.perf_counter() - retrieval_start
    if not contexts:
        raise HTTPException(status_code=404, detail="No context chunks retrieved for this query")

    return {
        "query": query,
        "rag_engine": rag_engine,
        "chunk_lookup": chunk_lookup,
        "model_name": model_name,
        "model_instance": model_instance,
        "contexts": contexts,
        "precheck_time": precheck_time,
        "state_access_time": state_access_time,
        "model_resolve_time": model_resolve_time,
        "retrieval_time": retrieval_time,
    }


# Predict endpoint: takes a query and returns an answer along with contexts
# and metrics. Called from the frontend when the user submits.
# Model resolution is based on the user's selection.
@app.post("/predict", response_model=PredictResponse)
def predict(payload: PredictRequest) -> PredictResponse:
    """Answer a query with retrieved contexts; logs per-stage timings."""
    req_start = time.perf_counter()
    prep = _prepare_prediction(payload)

    inference_start = time.perf_counter()
    answer = prep["rag_engine"].get_answer(
        prep["model_instance"], prep["query"], prep["contexts"], temperature=0.1
    )
    inference_time = time.perf_counter() - inference_start

    mapping_start = time.perf_counter()
    retrieved_chunks = _build_retrieved_chunks(
        contexts=prep["contexts"],
        chunk_lookup=prep["chunk_lookup"],
    )
    mapping_time = time.perf_counter() - mapping_start

    total_time = time.perf_counter() - req_start
    print(
        f"Predict timing | model={prep['model_name']} | mode={payload.mode} | "
        f"rerank={payload.rerank_strategy} | precheck={prep['precheck_time']:.3f}s | "
        f"state_access={prep['state_access_time']:.3f}s | model_resolve={prep['model_resolve_time']:.3f}s | "
        f"retrieval={prep['retrieval_time']:.3f}s | inference={inference_time:.3f}s | "
        f"context_map={mapping_time:.3f}s | total={total_time:.3f}s"
    )
    return PredictResponse(
        model=prep["model_name"],
        answer=answer,
        contexts=prep["contexts"],
        retrieved_chunks=retrieved_chunks,
    )


# Streaming endpoint: lets the frontend render tokens as they arrive instead
# of waiting for the full answer. Shares the retrieval pipeline with /predict.
@app.post("/predict/stream")
def predict_stream(payload: PredictRequest) -> StreamingResponse:
    """Stream answer tokens as NDJSON, ending with a "done" (or "error") record."""
    prep = _prepare_prediction(payload)
    query = prep["query"]
    model_name = prep["model_name"]
    model_instance = prep["model_instance"]
    contexts = prep["contexts"]
    rag_engine: RAGGenerator = prep["rag_engine"]
    chunk_lookup: dict[str, dict[str, Any]] = prep["chunk_lookup"]

    def stream_events():
        # Accumulate tokens so the final "done" record carries the full answer.
        answer_parts: list[str] = []
        try:
            for token in rag_engine.get_answer_stream(model_instance, query, contexts, temperature=0.1):
                answer_parts.append(token)
                yield _to_ndjson({"type": "token", "token": token})
            answer = "".join(answer_parts)
            retrieved_chunks = _build_retrieved_chunks(
                contexts=contexts,
                chunk_lookup=chunk_lookup,
            )
            yield _to_ndjson(
                {
                    "type": "done",
                    "model": model_name,
                    "answer": answer,
                    "contexts": contexts,
                    "retrieved_chunks": retrieved_chunks,
                }
            )
        except Exception as exc:
            # Mid-stream failures cannot change the HTTP status; report in-band.
            yield _to_ndjson({"type": "error", "message": f"Streaming failed: {exc}"})

    return StreamingResponse(
        stream_events(),
        media_type="application/x-ndjson",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",
        },
    )