Spaces:
Running
Running
| # retriever.py | |
| # This file handles the setup of embeddings, vector stores, and the ensemble retriever. | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| from langchain_chroma import Chroma | |
| from langchain_community.retrievers import BM25Retriever | |
| from langchain.retrievers import EnsembleRetriever | |
| from config import ( | |
| MODEL_NAME, MODEL_KWARGS, ENCODE_KWARGS, VECTOR_STORE_DIRECTORY, | |
| DENSE_RETRIEVER_K, KEYWORD_RETRIEVER_K, ENSEMBLE_WEIGHTS | |
| ) | |
def get_embedding_function():
    """Build the HuggingFace embedding model described by ``config``.

    Returns:
        A ``HuggingFaceEmbeddings`` instance created from ``MODEL_NAME``,
        ``MODEL_KWARGS`` and ``ENCODE_KWARGS``.
    """
    embedding_settings = {
        "model_name": MODEL_NAME,
        "model_kwargs": MODEL_KWARGS,
        "encode_kwargs": ENCODE_KWARGS,
    }
    return HuggingFaceEmbeddings(**embedding_settings)
def get_vector_store(embedding_function):
    """Open the Chroma vector store persisted at ``VECTOR_STORE_DIRECTORY``.

    Args:
        embedding_function: Embedding model handed to Chroma as its
            ``embedding_function``.

    Returns:
        A ``Chroma`` vector store bound to the configured persist directory.
    """
    store = Chroma(
        persist_directory=VECTOR_STORE_DIRECTORY,
        embedding_function=embedding_function,
    )
    return store
def _load_all_documents(vector_store):
    """Return every document stored in *vector_store*, or ``[]`` if it is empty."""
    # Only the ids are needed from this first call; include=[] skips pulling
    # document text and metadata here, since get_by_ids fetches them anyway.
    ids = vector_store.get(include=[]).get("ids", [])
    if not ids:
        return []
    return vector_store.get_by_ids(ids)


def get_ensemble_retriever():
    """
    Creates and returns an ensemble retriever combining dense and keyword-based search.

    The dense retriever queries the Chroma store and returns up to
    ``DENSE_RETRIEVER_K`` results. If the store already holds documents, a
    BM25 keyword retriever (``KEYWORD_RETRIEVER_K`` results) is built over
    all of them. The two are then combined with ``ENSEMBLE_WEIGHTS``.

    Returns:
        An ``EnsembleRetriever`` when the store contains documents. Otherwise
        the dense vector-store retriever on its own.

    Side effects:
        Prints progress messages to stdout.
    """
    print("Initializing embeddings and vector store...")
    embeddings = get_embedding_function()
    vector_store = get_vector_store(embeddings)
    # as_retriever() forwards its kwargs to VectorStoreRetriever, which has no
    # ``k`` field. The result count must go through search_kwargs, or it is
    # silently ignored and the default k is used.
    dense_vector_retriever = vector_store.as_retriever(
        search_kwargs={"k": DENSE_RETRIEVER_K}
    )

    print("Loading documents for BM25 retriever...")
    all_documents = _load_all_documents(vector_store)

    if all_documents:
        keyword_search_retriever = BM25Retriever.from_documents(
            documents=all_documents, k=KEYWORD_RETRIEVER_K
        )
        print("Creating ensemble retriever...")
        retriever = EnsembleRetriever(
            retrievers=[dense_vector_retriever, keyword_search_retriever],
            weights=ENSEMBLE_WEIGHTS,
        )
    else:
        # Nothing to index for BM25 yet (empty store), so fall back to
        # dense-only retrieval.
        print("Creating dense-only retriever...")
        retriever = dense_vector_retriever

    print("Retriever setup complete.")
    return retriever