# vibbabackend / retriever.py
# Author: prestiva
# Last commit: 92e030f ("UPDATE: book3")
# retriever.py
# This file handles the setup of embeddings, vector stores, and the ensemble retriever.
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever
from config import (
MODEL_NAME, MODEL_KWARGS, ENCODE_KWARGS, VECTOR_STORE_DIRECTORY,
DENSE_RETRIEVER_K, KEYWORD_RETRIEVER_K, ENSEMBLE_WEIGHTS
)
def get_embedding_function():
    """Build the HuggingFace embedding model described by ``config``.

    Returns:
        A ``HuggingFaceEmbeddings`` instance created from ``MODEL_NAME``,
        ``MODEL_KWARGS`` and ``ENCODE_KWARGS``.
    """
    settings = {
        "model_name": MODEL_NAME,
        "model_kwargs": MODEL_KWARGS,
        "encode_kwargs": ENCODE_KWARGS,
    }
    embedder = HuggingFaceEmbeddings(**settings)
    return embedder
def get_vector_store(embedding_function):
    """Open the Chroma vector store persisted under ``VECTOR_STORE_DIRECTORY``.

    Args:
        embedding_function: Embedding model handed to Chroma (in this module,
            the object returned by ``get_embedding_function``).

    Returns:
        A ``Chroma`` instance bound to ``VECTOR_STORE_DIRECTORY``.
    """
    store = Chroma(
        persist_directory=VECTOR_STORE_DIRECTORY,
        embedding_function=embedding_function,
    )
    return store
def _load_all_documents(vector_store):
    """Return every document stored in *vector_store*, or ``[]`` if it is empty."""
    # Ask Chroma for ids only (ids are always returned). The default ``get()``
    # would also pull all documents and metadata, which get_by_ids then
    # fetches a second time.
    ids = vector_store.get(include=[]).get("ids", [])
    if not ids:
        return []
    return vector_store.get_by_ids(ids)


def get_ensemble_retriever():
    """
    Create a retriever that combines dense (vector) and keyword (BM25) search.

    The dense retriever queries the persisted Chroma store and returns up to
    ``DENSE_RETRIEVER_K`` results. The BM25 retriever is built in memory from
    every document currently stored in that Chroma store and returns up to
    ``KEYWORD_RETRIEVER_K`` results. The two are combined by an
    ``EnsembleRetriever`` weighted by ``ENSEMBLE_WEIGHTS``.

    When the store holds no documents, there is nothing for BM25 to index.
    In that case the dense retriever is returned on its own.

    Returns:
        An ``EnsembleRetriever``, or the bare dense retriever when the vector
        store is empty.
    """
    print("Initializing embeddings and vector store...")
    embeddings = get_embedding_function()
    vector_store = get_vector_store(embeddings)
    # ``k`` must be passed via ``search_kwargs``. Given as a top-level keyword
    # to ``as_retriever`` it is silently ignored and the library default is
    # used instead of DENSE_RETRIEVER_K.
    dense_vector_retriever = vector_store.as_retriever(
        search_kwargs={"k": DENSE_RETRIEVER_K}
    )

    print("Loading documents for BM25 retriever...")
    all_documents = _load_all_documents(vector_store)

    if not all_documents:
        # Empty store: skip BM25 and fall back to dense search only.
        print("Creating dense-only retriever...")
        retriever = dense_vector_retriever
    else:
        keyword_search_retriever = BM25Retriever.from_documents(
            documents=all_documents, k=KEYWORD_RETRIEVER_K
        )
        print("Creating ensemble retriever...")
        # Weight order must match retriever order: [dense, keyword].
        retriever = EnsembleRetriever(
            retrievers=[dense_vector_retriever, keyword_search_retriever],
            weights=ENSEMBLE_WEIGHTS,
        )

    print("Retriever setup complete.")
    return retriever