|
|
""" |
|
|
Advanced RAG techniques for improved retrieval and generation |
|
|
Includes: Query Expansion, Multi-Query Retrieval, Reranking, Contextual Compression
|
|
""" |
|
|
|
|
|
from typing import List, Dict, Optional, Tuple |
|
|
import numpy as np |
|
|
from dataclasses import dataclass |
|
|
import re |
|
|
|
|
|
|
|
|
@dataclass
class RetrievedDocument:
    """A single document returned from the vector database.

    Populated from a search hit; ``confidence`` starts as the score
    reported by the vector store and is overwritten during reranking.
    """
    id: str            # point id in the vector store
    text: str          # raw document text (taken from the hit's metadata payload)
    confidence: float  # relevance score — presumably in [0, 1]; depends on backend
    metadata: Dict     # full payload returned with the hit


class AdvancedRAG:
    """Advanced RAG system with modern techniques.

    Pipeline stages: query expansion -> multi-query retrieval ->
    reranking -> contextual compression -> prompt construction.
    """

    def __init__(self, embedding_service, qdrant_service):
        """
        Args:
            embedding_service: object exposing ``encode_text(text) -> ndarray``.
            qdrant_service: object exposing
                ``search(query_embedding=..., limit=..., score_threshold=...)``
                returning dicts with ``id``, ``confidence`` and ``metadata`` keys.
        """
        self.embedding_service = embedding_service
        self.qdrant_service = qdrant_service

    def expand_query(self, query: str) -> List[str]:
        """Expand *query* into up to 3 variations (rule-based, Vietnamese-aware).

        Variations are built by stripping common question words and, for
        longer queries, by extracting a key phrase. The original query is
        always first in the returned list.

        Fix: question words are removed only at word boundaries; the
        previous ``str.replace`` also mangled longer words that merely
        contained a question word as a substring, and left doubled spaces.
        """
        queries = [query]

        # Common Vietnamese question/function words to strip out.
        question_words = ['ai', 'gì', 'nào', 'đâu', 'khi nào', 'như thế nào',
                          'tại sao', 'có', 'là', 'được', 'không']

        query_lower = query.lower()
        for qw in question_words:
            if qw in query_lower:
                # Remove whole-word occurrences only, then collapse the
                # whitespace left behind.
                variant = re.sub(r'(?<!\w)' + re.escape(qw) + r'(?!\w)', ' ', query_lower)
                variant = re.sub(r'\s+', ' ', variant).strip()
                if variant and variant != query_lower and variant not in queries:
                    queries.append(variant)

        # For longer queries, also try a shorter key phrase.
        words = query.split()
        if len(words) > 3:
            # Drop a leading question word if present, otherwise keep the
            # first 3 words as the key phrase.
            key_phrases = ' '.join(words[1:]) if words[0].lower() in question_words else ' '.join(words[:3])
            if key_phrases not in queries:
                queries.append(key_phrases)

        return queries[:3]

    def multi_query_retrieval(
        self,
        query: str,
        top_k: int = 5,
        score_threshold: float = 0.5
    ) -> List[RetrievedDocument]:
        """Retrieve documents for every query variation and merge the results.

        When the same document is returned for several variations, the
        highest-confidence occurrence wins. Results are sorted by
        confidence (descending) and truncated to *top_k*.
        """
        expanded_queries = self.expand_query(query)

        # doc id -> best RetrievedDocument seen so far
        all_results: Dict[str, RetrievedDocument] = {}

        for q in expanded_queries:
            query_embedding = self.embedding_service.encode_text(q)

            results = self.qdrant_service.search(
                query_embedding=query_embedding,
                limit=top_k,
                score_threshold=score_threshold
            )

            for result in results:
                doc_id = result["id"]
                # Keep only the best-scoring occurrence of each document.
                if doc_id not in all_results or result["confidence"] > all_results[doc_id].confidence:
                    all_results[doc_id] = RetrievedDocument(
                        id=doc_id,
                        text=result["metadata"].get("text", ""),
                        confidence=result["confidence"],
                        metadata=result["metadata"]
                    )

        sorted_results = sorted(all_results.values(), key=lambda x: x.confidence, reverse=True)
        return sorted_results[:top_k]

    def rerank_documents(
        self,
        query: str,
        documents: List[RetrievedDocument],
        use_cross_encoder: bool = False
    ) -> List[RetrievedDocument]:
        """Rerank *documents* by blending embedding similarity with the
        retrieval confidence: ``0.6 * similarity + 0.4 * confidence``.

        Args:
            use_cross_encoder: reserved for a future cross-encoder
                reranker; currently ignored.

        Fix: the similarity is now a proper cosine similarity. The
        previous raw dot product is unbounded for unnormalized embeddings
        and could swamp the confidence term in the weighted sum. (For
        unit-norm embeddings the result is unchanged.)
        """
        if not documents:
            return documents

        query_embedding = np.asarray(self.embedding_service.encode_text(query), dtype=float).flatten()
        query_norm = np.linalg.norm(query_embedding)

        reranked = []
        for doc in documents:
            doc_embedding = np.asarray(self.embedding_service.encode_text(doc.text), dtype=float).flatten()
            doc_norm = np.linalg.norm(doc_embedding)

            # Cosine similarity, guarded against zero-length vectors.
            if query_norm > 0.0 and doc_norm > 0.0:
                similarity = float(np.dot(query_embedding, doc_embedding) / (query_norm * doc_norm))
            else:
                similarity = 0.0

            # Blend semantic similarity with the original retrieval score.
            new_score = 0.6 * similarity + 0.4 * doc.confidence

            reranked.append(RetrievedDocument(
                id=doc.id,
                text=doc.text,
                confidence=float(new_score),
                metadata=doc.metadata
            ))

        reranked.sort(key=lambda x: x.confidence, reverse=True)
        return reranked

    def compress_context(
        self,
        query: str,
        documents: List[RetrievedDocument],
        max_tokens: int = 500
    ) -> List[RetrievedDocument]:
        """Compress each document to the sentences most relevant to *query*.

        Sentences are scored by word overlap with the query and greedily
        packed, most relevant first, until the (whitespace-word) budget
        *max_tokens* is reached. If no sentence overlaps, a
        character-truncated prefix of the original text is kept instead so
        the document is never emptied.
        """
        compressed_docs = []
        query_words = set(query.lower().split())  # loop-invariant, hoisted

        for doc in documents:
            sentences = self._split_sentences(doc.text)

            # Score each sentence by the number of query words it shares.
            scored_sentences = []
            for sent in sentences:
                overlap = len(query_words & set(sent.lower().split()))
                if overlap > 0:
                    scored_sentences.append((sent, overlap))

            # Most relevant sentences first; the stable sort preserves the
            # original order among equally-scored sentences.
            scored_sentences.sort(key=lambda x: x[1], reverse=True)

            # Greedily pack sentences until the word budget is exhausted.
            kept = []
            word_count = 0
            for sent, _score in scored_sentences:
                sent_len = len(sent.split())
                if word_count + sent_len > max_tokens:
                    break
                kept.append(sent)
                word_count += sent_len

            compressed_text = ' '.join(kept)

            # Fallback: nothing matched the query — keep a rough
            # character-budget prefix (~5 chars per token heuristic).
            if not compressed_text.strip():
                compressed_text = doc.text[:max_tokens * 5]

            compressed_docs.append(RetrievedDocument(
                id=doc.id,
                text=compressed_text.strip(),
                confidence=doc.confidence,
                metadata=doc.metadata
            ))

        return compressed_docs

    def _split_sentences(self, text: str) -> List[str]:
        """Split *text* into sentences on runs of ``.``/``!``/``?``.

        Good enough for this pipeline; no abbreviation handling.
        """
        sentences = re.split(r'[.!?]+', text)
        return [s.strip() for s in sentences if s.strip()]

    def hybrid_rag_pipeline(
        self,
        query: str,
        top_k: int = 5,
        score_threshold: float = 0.5,
        use_reranking: bool = True,
        use_compression: bool = True,
        max_context_tokens: int = 500
    ) -> Tuple[List[RetrievedDocument], Dict]:
        """Run the complete advanced RAG pipeline.

        Steps:
            1. Multi-query retrieval (over-fetches ``2 * top_k`` candidates).
            2. Optional reranking, truncated back to *top_k*.
            3. Optional contextual compression.

        Returns:
            ``(documents, stats)`` where *stats* records the original
            query, its expansions, and document counts after each stage
            (counts for skipped stages stay 0).
        """
        stats = {
            "original_query": query,
            "expanded_queries": [],
            "initial_results": 0,
            "after_rerank": 0,
            "after_compression": 0
        }

        # NOTE: expansion also runs inside multi_query_retrieval; it is
        # cheap, and calling it here keeps the stats self-contained.
        stats["expanded_queries"] = self.expand_query(query)

        # Over-fetch so the reranker has candidates to choose from.
        documents = self.multi_query_retrieval(
            query=query,
            top_k=top_k * 2,
            score_threshold=score_threshold
        )
        stats["initial_results"] = len(documents)

        if use_reranking and documents:
            documents = self.rerank_documents(query, documents)
            documents = documents[:top_k]
            stats["after_rerank"] = len(documents)

        if use_compression and documents:
            documents = self.compress_context(
                query=query,
                documents=documents,
                max_tokens=max_context_tokens
            )
            stats["after_compression"] = len(documents)

        return documents, stats

    def format_context_for_llm(
        self,
        documents: List[RetrievedDocument],
        include_metadata: bool = True
    ) -> str:
        """Format retrieved documents into a structured context string.

        Each document gets a numbered header with its relevance score
        and, optionally, a compact metadata line. The raw ``text``/
        ``texts`` payload fields and falsy metadata values are skipped.
        Returns ``""`` when there are no documents.
        """
        if not documents:
            return ""

        context_parts = ["RELEVANT CONTEXT:\n"]

        for i, doc in enumerate(documents, 1):
            context_parts.append(f"\n--- Document {i} (Relevance: {doc.confidence:.2%}) ---")
            context_parts.append(doc.text)

            if include_metadata and doc.metadata:
                meta_str = [
                    f"{key}: {value}"
                    for key, value in doc.metadata.items()
                    if key not in ('text', 'texts') and value
                ]
                if meta_str:
                    context_parts.append(f"[Metadata: {', '.join(meta_str)}]")

        context_parts.append("\n--- End of Context ---\n")
        return "\n".join(context_parts)

    def build_rag_prompt(
        self,
        query: str,
        context: str,
        system_message: str = "You are a helpful AI assistant."
    ) -> str:
        """Build the final RAG prompt: system message, context block,
        grounding instructions, then the user question."""
        prompt_template = f"""{system_message}

{context}

INSTRUCTIONS:
1. Answer the user's question using ONLY the information provided in the context above
2. If the context doesn't contain relevant information, say "Tôi không tìm thấy thông tin liên quan trong dữ liệu."
3. Cite relevant parts of the context when answering
4. Be concise and accurate
5. Answer in Vietnamese if the question is in Vietnamese

USER QUESTION: {query}

YOUR ANSWER:"""

        return prompt_template
|
|
|