# NLP-RAG / api.py
# Last commit: 23e3c5c — "added chunk details" (Qar-Raz)
# Fastapi endpoints defined here
import json
import os
import re
import time
from typing import Any
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from huggingface_hub import InferenceClient
from pydantic import BaseModel, Field
from config_loader import cfg
from vector_db import get_index_by_name, load_chunks_with_local_cache
from retriever.retriever import HybridRetriever
from retriever.generator import RAGGenerator
from retriever.processor import ChunkProcessor
from models.llama_3_8b import Llama3_8B
from models.mistral_7b import Mistral_7b
from models.qwen_2_5 import Qwen2_5
from models.deepseek_v3 import DeepSeek_V3
from models.tiny_aya import TinyAya
# Added caching and time logging to track each stage's duration.
class PredictRequest(BaseModel):
    """Request body for /predict and /predict/stream."""
    query: str = Field(..., min_length=1, description="User query text")
    model: str = Field(default="Llama-3-8B", description="Model name key")
    # Number of candidates pulled by the retriever before reranking.
    top_k: int = Field(default=10, ge=1, le=50)
    # Number of contexts kept after reranking and passed to the LLM.
    final_k: int = Field(default=5, ge=1, le=20)
    mode: str = Field(default="hybrid", description="semantic | bm25 | hybrid")
    rerank_strategy: str = Field(default="cross-encoder", description="cross-encoder | rrf | none")
class PredictResponse(BaseModel):
    """Response body for /predict: the answer plus its supporting contexts."""
    model: str
    answer: str
    # Raw context texts, in rank order.
    contexts: list[str]
    # Contexts enriched with source metadata (see _build_retrieved_chunks).
    retrieved_chunks: list[dict[str, Any]]
class TitleRequest(BaseModel):
    """Request body for /predict/title."""
    query: str = Field(..., min_length=1, description="First user message")
class TitleResponse(BaseModel):
    """Response body for /predict/title."""
    title: str
    # How the title was produced: "hf:<model_id>" or "rule-based".
    source: str
def _to_ndjson(payload: dict[str, Any]) -> str:
return json.dumps(payload, ensure_ascii=False) + "\n"
# Simplest possible implementation to determine a chat title.
# Used as a fallback in case HF generation fails.
def _title_from_query(query: str) -> str:
stop_words = {
"a", "an", "and", "are", "as", "at", "be", "by", "can", "do", "for", "from", "how",
"i", "in", "is", "it", "me", "my", "of", "on", "or", "please", "show", "tell", "that",
"the", "this", "to", "we", "what", "when", "where", "which", "why", "with", "you", "your",
}
words = re.findall(r"[A-Za-z0-9][A-Za-z0-9\-_/+]*", query)
if not words:
return "New Chat"
filtered: list[str] = []
for word in words:
cleaned = word.strip("-_/+")
if not cleaned:
continue
if cleaned.lower() in stop_words:
continue
filtered.append(cleaned)
if len(filtered) >= 6:
break
chosen = filtered if filtered else words[:6]
normalized = [w.capitalize() if w.islower() else w for w in chosen]
title = " ".join(normalized).strip()
return title[:80] if title else "New Chat"
# Title generation using an HF model: a simple prompt produces a concise title
# from the user query, with formatting rules to ensure clean output. If
# generation fails or returns an empty title, we fall back to the rule-based
# method. Called from the /predict/title endpoint.
def _clean_title_text(raw: str) -> str:
text = (raw or "").strip()
text = text.replace("\n", " ").replace("\r", " ")
text = re.sub(r"^[\"'`\s]+|[\"'`\s]+$", "", text)
text = re.sub(r"\s+", " ", text).strip()
words = text.split()
if len(words) > 8:
text = " ".join(words[:8])
return text[:80]
def _title_from_hf(query: str, client: InferenceClient, model_id: str) -> str | None:
    """Ask an HF chat model for a short title; return None when unusable.

    Returns None when the API yields no choices, or when the cleaned title
    is empty or the generic "new chat" — callers then try the next model.
    """
    system_prompt = (
        "You generate short chat titles. Return only a title, no punctuation at the end, no quotes."
    )
    user_prompt = (
        "Create a concise 3-7 word title for this user request:\n"
        f"{query}"
    )
    completion = client.chat_completion(
        model=model_id,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        max_tokens=24,
        temperature=0.3,
    )
    if not completion or not completion.choices:
        return None
    candidate = _clean_title_text(completion.choices[0].message.content or "")
    if not candidate or candidate.lower() == "new chat":
        return None
    return candidate
def _parse_title_model_candidates() -> list[str]:
raw = os.getenv(
"TITLE_MODEL_IDS",
"Qwen/Qwen2.5-1.5B-Instruct,CohereLabs/tiny-aya-global,meta-llama/Meta-Llama-3-8B-Instruct",
)
models = [m.strip() for m in raw.split(",") if m.strip()]
return models or ["meta-llama/Meta-Llama-3-8B-Instruct"]
def _build_retrieved_chunks(
contexts: list[str],
chunk_lookup: dict[str, dict[str, Any]],
) -> list[dict[str, Any]]:
if not contexts:
return []
retrieved_chunks: list[dict[str, Any]] = []
for idx, text in enumerate(contexts, start=1):
meta = chunk_lookup.get(text, {})
title = meta.get("title") or "Untitled"
url = meta.get("url") or ""
chunk_index = meta.get("chunk_index")
retrieved_chunks.append(
{
"rank": idx,
"text": text,
"source_title": title,
"source_url": url,
"chunk_index": chunk_index,
}
)
return retrieved_chunks
# FastAPI setup.
# FastAPI lets us define Python-based endpoints
# that are called from the React-based frontend.
app = FastAPI(title="RAG-AS3 API", version="0.1.0")
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# disallowed by the CORS spec (browsers reject credentialed wildcard origins);
# consider listing the frontend origin(s) explicitly — confirm with frontend.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Global mutable app state, populated once by startup_event().
state: dict[str, Any] = {}
def _build_models(hf_token: str) -> dict[str, Any]:
    """Instantiate every chat model the API exposes, keyed by display name."""
    model_classes = {
        "Llama-3-8B": Llama3_8B,
        "Mistral-7B": Mistral_7b,
        "Qwen-2.5": Qwen2_5,
        "DeepSeek-V3": DeepSeek_V3,
        "TinyAya": TinyAya,
    }
    return {name: model_cls(token=hf_token) for name, model_cls in model_classes.items()}
def _resolve_model(name: str, models: dict[str, Any]) -> tuple[str, Any]:
aliases = {
"llama": "Llama-3-8B",
"mistral": "Mistral-7B",
"qwen": "Qwen-2.5",
"deepseek": "DeepSeek-V3",
"tinyaya": "TinyAya",
}
model_key = aliases.get(name.lower(), name)
if model_key not in models:
allowed = ", ".join(models.keys())
raise HTTPException(status_code=400, detail=f"Unknown model '{name}'. Use one of: {allowed}")
return model_key, models[model_key]
# On startup most of the time is spent loading chunks from Pinecone.
# This is done because BM25 needs the entire corpus in memory, and
# we want to avoid loading it on every query, so loading it at startup is better.
# Could improve this: holding the entire corpus in memory is not ideal
# and currently it won't scale well.
@app.on_event("startup")
def startup_event() -> None:
    """Load the corpus, build the retrieval stack, and populate `state`.

    Runs once at process start. Every stage is individually timed and the
    totals are printed at the end so slow startups can be diagnosed.

    Raises:
        RuntimeError: when required credentials are missing or the index
            contains no chunks.

    NOTE(review): `@app.on_event` is deprecated in recent FastAPI releases
    in favor of lifespan handlers — worth migrating when convenient.
    """
    startup_start = time.perf_counter()
    dotenv_start = time.perf_counter()
    load_dotenv()
    dotenv_time = time.perf_counter() - dotenv_start
    env_start = time.perf_counter()
    hf_token = os.getenv("HF_TOKEN")
    pinecone_api_key = os.getenv("PINECONE_API_KEY")
    env_time = time.perf_counter() - env_start
    # Fail fast: both credentials are mandatory for retrieval and generation.
    if not pinecone_api_key:
        raise RuntimeError("PINECONE_API_KEY not found in environment variables")
    if not hf_token:
        raise RuntimeError("HF_TOKEN not found in environment variables")
    index_name = "cbt-book-recursive"
    # Keep retrieval embedding model aligned with the one used at ingest time
    # to avoid Pinecone dimension mismatch errors (e.g., 384 vs 512).
    embed_model_name = cfg.processing.get("embedding_model", "all-MiniLM-L6-v2")
    project_root = os.path.dirname(os.path.abspath(__file__))
    cache_dir = os.getenv("BM25_CACHE_DIR", os.path.join(project_root, ".cache"))
    force_cache_refresh = os.getenv("BM25_CACHE_REFRESH", "0").lower() in {"1", "true", "yes"}
    index_start = time.perf_counter()
    index = get_index_by_name(
        api_key=pinecone_api_key,
        index_name=index_name
    )
    index_time = time.perf_counter() - index_start
    chunks_start = time.perf_counter()
    # BM25 needs the whole corpus in memory; chunks are cached locally so
    # subsequent startups can skip the full Pinecone scan.
    final_chunks, chunk_source = load_chunks_with_local_cache(
        index=index,
        index_name=index_name,
        cache_dir=cache_dir,
        batch_size=100,
        force_refresh=force_cache_refresh,
    )
    chunk_load_time = time.perf_counter() - chunks_start
    if not final_chunks:
        raise RuntimeError("No chunks found in Pinecone metadata. Run indexing once before API mode.")
    processor_start = time.perf_counter()
    proc = ChunkProcessor(model_name=embed_model_name, verbose=False, load_hf_embeddings=False)
    processor_time = time.perf_counter() - processor_start
    retriever_start = time.perf_counter()
    retriever = HybridRetriever(final_chunks, proc.encoder, verbose=False)
    retriever_time = time.perf_counter() - retriever_start
    rag_start = time.perf_counter()
    rag_engine = RAGGenerator()
    rag_time = time.perf_counter() - rag_start
    models_start = time.perf_counter()
    models = _build_models(hf_token)
    models_time = time.perf_counter() - models_start
    state_start = time.perf_counter()
    # Map chunk text -> source metadata so the predict endpoints can annotate
    # returned contexts. First occurrence wins for duplicate texts.
    chunk_lookup: dict[str, dict[str, Any]] = {}
    for chunk in final_chunks:
        metadata = chunk.get("metadata", {})
        text = metadata.get("text")
        if not text or text in chunk_lookup:
            continue
        chunk_lookup[text] = {
            "title": metadata.get("title", "Untitled"),
            "url": metadata.get("url", ""),
            "chunk_index": metadata.get("chunk_index"),
        }
    state["index"] = index
    state["retriever"] = retriever
    state["rag_engine"] = rag_engine
    state["models"] = models
    state["chunk_lookup"] = chunk_lookup
    state["title_model_ids"] = _parse_title_model_candidates()
    state["title_client"] = InferenceClient(token=hf_token)
    state_time = time.perf_counter() - state_start
    startup_time = time.perf_counter() - startup_start
    print(
        f"API startup complete | chunks={len(final_chunks)} | "
        f"dotenv={dotenv_time:.3f}s | "
        f"env={env_time:.3f}s | "
        f"index={index_time:.3f}s | "
        f"cache_dir={cache_dir} | "
        f"force_cache_refresh={force_cache_refresh} | "
        f"chunk_source={chunk_source} | "
        f"chunk_load={chunk_load_time:.3f}s | "
        f"processor={processor_time:.3f}s | "
        f"retriever={retriever_time:.3f}s | "
        f"rag={rag_time:.3f}s | "
        f"models={models_time:.3f}s | "
        f"state={state_time:.3f}s | "
        f"total={startup_time:.3f}s"
    )
@app.get("/health")
def health() -> dict[str, str]:
    """Readiness probe: 'ok' once startup has populated the core state keys."""
    required = ("index", "retriever", "rag_engine", "models")
    is_ready = all(key in state for key in required)
    return {"status": "ok" if is_ready else "starting"}
# Title generation endpoint.
# Called only once, when a new chat is created after the first prompt.
@app.post("/predict/title", response_model=TitleResponse)
def suggest_title(payload: TitleRequest) -> TitleResponse:
    """Generate a short chat title for the first message of a new chat.

    Tries each configured HF model in order and falls back to the
    rule-based title when no model is available or all of them fail.
    """
    query = payload.query.strip()
    if not query:
        raise HTTPException(status_code=400, detail="Query cannot be empty")
    fallback_title = _title_from_query(query)
    client: InferenceClient | None = state.get("title_client")
    candidate_ids: list[str] = state.get("title_model_ids", _parse_title_model_candidates())
    if client is not None:
        for model_id in candidate_ids:
            try:
                generated = _title_from_hf(query, client, model_id)
            except Exception as exc:
                message = str(exc)
                # Provider/model availability differs across HF accounts; skip unsupported models.
                if "model_not_supported" not in message and "not supported by any provider" not in message:
                    print(f"Title generation model failed ({model_id}): {exc}")
                continue
            if generated:
                return TitleResponse(title=generated, source=f"hf:{model_id}")
    print("Title generation fallback triggered: no title model available/successful")
    return TitleResponse(title=fallback_title, source="rule-based")
# Predict endpoint: takes a query and returns an answer along with contexts
# and per-stage timing metrics. Called from the frontend when the user submits.
# Also resolves the model based on the user's selection.
@app.post("/predict", response_model=PredictResponse)
def predict(payload: PredictRequest) -> PredictResponse:
    """Answer a query end-to-end: retrieve contexts, generate, annotate sources.

    Raises:
        HTTPException 503: startup has not populated state yet.
        HTTPException 400: blank query, or unknown model (via _resolve_model).
        HTTPException 404: retriever returned no contexts.
    """
    req_start = time.perf_counter()
    precheck_start = time.perf_counter()
    if not state:
        raise HTTPException(status_code=503, detail="Service not initialized yet")
    query = payload.query.strip()
    if not query:
        raise HTTPException(status_code=400, detail="Query cannot be empty")
    precheck_time = time.perf_counter() - precheck_start
    state_access_start = time.perf_counter()
    retriever: HybridRetriever = state["retriever"]
    index = state["index"]
    rag_engine: RAGGenerator = state["rag_engine"]
    models: dict[str, Any] = state["models"]
    chunk_lookup: dict[str, dict[str, Any]] = state.get("chunk_lookup", {})
    state_access_time = time.perf_counter() - state_access_start
    model_resolve_start = time.perf_counter()
    model_name, model_instance = _resolve_model(payload.model, models)
    model_resolve_time = time.perf_counter() - model_resolve_start
    retrieval_start = time.perf_counter()
    contexts = retriever.search(
        query,
        index,
        mode=payload.mode,
        rerank_strategy=payload.rerank_strategy,
        use_mmr=True,
        top_k=payload.top_k,
        final_k=payload.final_k,
        verbose=False,
    )
    retrieval_time = time.perf_counter() - retrieval_start
    if not contexts:
        raise HTTPException(status_code=404, detail="No context chunks retrieved for this query")
    inference_start = time.perf_counter()
    # Low temperature keeps answers grounded in the retrieved contexts.
    answer = rag_engine.get_answer(model_instance, query, contexts, temperature=0.1)
    inference_time = time.perf_counter() - inference_start
    mapping_start = time.perf_counter()
    retrieved_chunks = _build_retrieved_chunks(
        contexts=contexts,
        chunk_lookup=chunk_lookup,
    )
    mapping_time = time.perf_counter() - mapping_start
    total_time = time.perf_counter() - req_start
    # Per-stage timing line for latency diagnosis.
    print(
        f"Predict timing | model={model_name} | mode={payload.mode} | "
        f"rerank={payload.rerank_strategy} | precheck={precheck_time:.3f}s | "
        f"state_access={state_access_time:.3f}s | model_resolve={model_resolve_time:.3f}s | "
        f"retrieval={retrieval_time:.3f}s | inference={inference_time:.3f}s | "
        f"context_map={mapping_time:.3f}s | total={total_time:.3f}s"
    )
    return PredictResponse(
        model=model_name,
        answer=answer,
        contexts=contexts,
        retrieved_chunks=retrieved_chunks,
    )
# Streaming endpoint: lets the frontend render tokens as they arrive
# instead of waiting for the full answer.
@app.post("/predict/stream")
def predict_stream(payload: PredictRequest) -> StreamingResponse:
    """Same pipeline as /predict, but streams tokens as NDJSON events.

    Validation and retrieval happen before the response starts, so those
    failures still surface as regular HTTPExceptions. Once streaming has
    begun, errors are reported in-band as a {"type": "error", ...} event
    because the HTTP status can no longer change.
    """
    req_start = time.perf_counter()
    precheck_start = time.perf_counter()
    if not state:
        raise HTTPException(status_code=503, detail="Service not initialized yet")
    query = payload.query.strip()
    if not query:
        raise HTTPException(status_code=400, detail="Query cannot be empty")
    precheck_time = time.perf_counter() - precheck_start
    state_access_start = time.perf_counter()
    retriever: HybridRetriever = state["retriever"]
    index = state["index"]
    rag_engine: RAGGenerator = state["rag_engine"]
    models: dict[str, Any] = state["models"]
    chunk_lookup: dict[str, dict[str, Any]] = state.get("chunk_lookup", {})
    state_access_time = time.perf_counter() - state_access_start
    model_resolve_start = time.perf_counter()
    model_name, model_instance = _resolve_model(payload.model, models)
    model_resolve_time = time.perf_counter() - model_resolve_start
    retrieval_start = time.perf_counter()
    contexts = retriever.search(
        query,
        index,
        mode=payload.mode,
        rerank_strategy=payload.rerank_strategy,
        use_mmr=True,
        top_k=payload.top_k,
        final_k=payload.final_k,
        verbose=False,
    )
    retrieval_time = time.perf_counter() - retrieval_start
    if not contexts:
        raise HTTPException(status_code=404, detail="No context chunks retrieved for this query")
    def stream_events():
        # Generator closed over contexts/model; yields one NDJSON line per event:
        # "token" events while generating, then a final "done" summary event.
        inference_start = time.perf_counter()
        answer_parts: list[str] = []
        try:
            for token in rag_engine.get_answer_stream(model_instance, query, contexts, temperature=0.1):
                answer_parts.append(token)
                yield _to_ndjson({"type": "token", "token": token})
            inference_time = time.perf_counter() - inference_start
            answer = "".join(answer_parts)
            retrieved_chunks = _build_retrieved_chunks(
                contexts=contexts,
                chunk_lookup=chunk_lookup,
            )
            yield _to_ndjson(
                {
                    "type": "done",
                    "model": model_name,
                    "answer": answer,
                    "contexts": contexts,
                    "retrieved_chunks": retrieved_chunks,
                }
            )
        except Exception as exc:
            # Headers were already sent; report the failure in-band.
            yield _to_ndjson({"type": "error", "message": f"Streaming failed: {exc}"})
    return StreamingResponse(
        stream_events(),
        media_type="application/x-ndjson",
        headers={
            # Disable proxy buffering so tokens reach the client immediately.
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",
        },
    )