"""Hybrid RAG retrieval pipeline: BM25 + semantic search over Chroma,
cross-encoder reranking, optional LLM query expansion, and answer generation."""

import json
import pickle
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import chromadb
import numpy as np
from rank_bm25 import BM25Okapi
from sentence_transformers import CrossEncoder, SentenceTransformer


class QueryExpander:
    """Generates alternative phrasings of a user query via an OpenAI chat model."""

    def __init__(self, openai_client):
        self.client = openai_client

    def expand_query(self, query: str, num_variations: int = 3) -> List[str]:
        """Return the original query followed by up to `num_variations` paraphrases."""
        prompt = f"""Given this user query, generate {num_variations} alternative phrasings that capture the same intent but use different words.

Original query: {query}

Return ONLY a JSON array of alternative queries, nothing else.
Example: ["query1", "query2", "query3"]"""
        try:
            response = self.client.chat.completions.create(
                model="gpt-4o-mini",
                messages=[{"role": "user", "content": prompt}],
                temperature=0.7,
            )
            content = response.choices[0].message.content.strip()
            # Models sometimes wrap JSON in a markdown code fence; strip it so
            # json.loads does not fail spuriously.
            if content.startswith("```"):
                content = content.strip("`")
                if content.startswith("json"):
                    content = content[4:]
            variations = json.loads(content)
            return [query] + variations
        except Exception as e:
            # Degrade gracefully: retrieval proceeds with the unexpanded query.
            print(f"Query expansion failed: {e}")
            return [query]
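

# A minimal usage sketch (assumes the `openai` package and an OPENAI_API_KEY in
# the environment; the example query and output are illustrative only):
#
#   from openai import OpenAI
#   expander = QueryExpander(OpenAI())
#   expander.expand_query("how do I reset my password?")
#   # -> ["how do I reset my password?", "steps to recover my account", ...]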


class HybridRetriever:
    """Combines lexical (BM25) and semantic (Chroma) retrieval, with optional
    cross-encoder reranking on top."""

    def __init__(
        self,
        embedding_model: str = "all-MiniLM-L6-v2",
        reranker_model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
        data_dir: str = "data",
    ):
        self.data_dir = Path(data_dir)
        self.data_dir.mkdir(exist_ok=True)
        print("Loading embedding model...")
        self.embedder = SentenceTransformer(embedding_model)
        print("Loading reranker model...")
        self.reranker = CrossEncoder(reranker_model)
        # Note: collections are created below without an explicit embedding
        # function, so Chroma embeds text with its built-in default model
        # rather than with self.embedder.
        self.chroma_client = chromadb.PersistentClient(path=str(self.data_dir / "vector_store"))
        self.documents: List[Dict] = []
        self._doc_by_id: Dict[str, Dict] = {}  # id -> chunk, for O(1) lookups
        self.bm25: Optional[BM25Okapi] = None
        self.collection = None

    def chunk_text(self, text: str, chunk_size: int = 500, overlap: int = 50) -> List[str]:
        """Split text into windows of `chunk_size` words, overlapping by `overlap`."""
        if overlap >= chunk_size:
            raise ValueError("overlap must be smaller than chunk_size")
        words = text.split()
        chunks = []
        for i in range(0, len(words), chunk_size - overlap):
            chunk = ' '.join(words[i:i + chunk_size])
            if chunk:
                chunks.append(chunk)
        return chunks
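
    # Worked example of the stride: with chunk_size=500 and overlap=50 the
    # window advances 450 words per step, so chunk 0 covers words 0-499,
    # chunk 1 covers words 450-949, and consecutive chunks share 50 words.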

    def index_documents(self, documents: Dict[str, str], chunk_size: int = 500, overlap: int = 50, collection_name: str = "knowledge_base"):
        """Chunk the raw documents and build both the BM25 and Chroma indexes."""
        print(f"Indexing documents with chunk_size={chunk_size}, overlap={overlap}")
        all_chunks = []
        for doc_id, content in documents.items():
            chunks = self.chunk_text(content, chunk_size, overlap)
            for idx, chunk in enumerate(chunks):
                all_chunks.append({
                    "id": f"{doc_id}_{idx}",
                    "text": chunk,
                    "source": doc_id,
                    "chunk_idx": idx,
                })
        if not all_chunks:
            raise ValueError("No text chunks created from documents. Please check your document content.")
        self.documents = all_chunks
        self._doc_by_id = {doc["id"]: doc for doc in all_chunks}
        print("Building BM25 index...")
        # BM25 works on whitespace tokens; lowercasing makes matching case-insensitive.
        tokenized_docs = [doc["text"].lower().split() for doc in all_chunks]
        self.bm25 = BM25Okapi(tokenized_docs)
        print("Building semantic index...")
        try:
            # Drop any stale collection so re-indexing starts clean.
            self.chroma_client.delete_collection(collection_name)
        except Exception:
            pass
        self.collection = self.chroma_client.create_collection(name=collection_name, metadata={"hnsw:space": "cosine"})
        # Add in batches to keep individual Chroma requests small.
        batch_size = 100
        for i in range(0, len(all_chunks), batch_size):
            batch = all_chunks[i:i + batch_size]
            self.collection.add(
                documents=[doc["text"] for doc in batch],
                ids=[doc["id"] for doc in batch],
                metadatas=[{"source": doc["source"], "chunk_idx": doc["chunk_idx"]} for doc in batch],
            )
        print(f"Indexed {len(all_chunks)} chunks from {len(documents)} documents")
        # Persist the lexical index so it can be reloaded without re-chunking.
        with open(self.data_dir / "bm25_index.pkl", "wb") as f:
            pickle.dump((self.bm25, self.documents), f)
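
    # Note: no matching loader is defined in this module; a minimal sketch to
    # restore the persisted lexical index (assumes index_documents previously
    # ran against the same data_dir) might look like:
    #
    #   with open(self.data_dir / "bm25_index.pkl", "rb") as f:
    #       self.bm25, self.documents = pickle.load(f)
    #   self._doc_by_id = {doc["id"]: doc for doc in self.documents}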

    def retrieve_bm25(self, query: str, top_k: int = 10) -> List[Tuple[Dict, float]]:
        """Lexical retrieval: return (chunk, BM25 score) pairs with positive scores."""
        if self.bm25 is None:
            return []
        tokenized_query = query.lower().split()
        scores = self.bm25.get_scores(tokenized_query)
        top_indices = np.argsort(scores)[::-1][:top_k]
        results = []
        for idx in top_indices:
            if scores[idx] > 0:
                results.append((self.documents[idx], float(scores[idx])))
        return results

    def retrieve_semantic(self, query: str, top_k: int = 10) -> List[Tuple[Dict, float]]:
        """Vector retrieval via Chroma: return (chunk, similarity) pairs."""
        if self.collection is None:
            return []
        results = self.collection.query(query_texts=[query], n_results=top_k)
        retrieved = []
        for i, doc_id in enumerate(results["ids"][0]):
            doc = self._doc_by_id.get(doc_id)
            if doc:
                distance = results["distances"][0][i]
                # Map distance to a similarity in (0, 1]; smaller distances
                # yield scores closer to 1.
                similarity = 1 / (1 + distance)
                retrieved.append((doc, similarity))
        return retrieved

    def retrieve_hybrid(self, query: str, top_k: int = 10, bm25_weight: float = 0.5, semantic_weight: float = 0.5) -> List[Tuple[Dict, float]]:
        """Fuse BM25 and semantic results via min-max normalization and weighted sum."""
        bm25_results = self.retrieve_bm25(query, top_k * 2)
        semantic_results = self.retrieve_semantic(query, top_k * 2)

        def normalize_scores(results):
            # Min-max normalize each retriever's scores into [0, 1] so the two
            # otherwise incomparable scales can be weighted against each other.
            if not results:
                return {}
            scores = [score for _, score in results]
            max_score = max(scores)
            min_score = min(scores)
            range_score = max_score - min_score if max_score != min_score else 1.0
            return {doc["id"]: (score - min_score) / range_score for doc, score in results}

        bm25_scores = normalize_scores(bm25_results)
        semantic_scores = normalize_scores(semantic_results)
        all_doc_ids = set(bm25_scores.keys()) | set(semantic_scores.keys())
        combined_scores = {}
        for doc_id in all_doc_ids:
            bm25_score = bm25_scores.get(doc_id, 0.0)
            semantic_score = semantic_scores.get(doc_id, 0.0)
            combined_scores[doc_id] = bm25_weight * bm25_score + semantic_weight * semantic_score
        sorted_ids = sorted(combined_scores.items(), key=lambda x: x[1], reverse=True)[:top_k]
        results = []
        for doc_id, score in sorted_ids:
            doc = self._doc_by_id.get(doc_id)
            if doc:
                results.append((doc, score))
        return results
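
    # Worked example of the fusion: a chunk with normalized BM25 score 1.0 and
    # normalized semantic score 0.4 gets 0.5 * 1.0 + 0.5 * 0.4 = 0.7 under the
    # default equal weights.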

    def rerank(self, query: str, documents: List[Tuple[Dict, float]], top_k: int = 5) -> List[Tuple[Dict, float]]:
        """Rescore candidates with the cross-encoder and keep the top_k.

        Cross-encoder outputs are raw relevance scores on their own scale, so
        they replace (rather than mix with) the first-stage retrieval scores.
        """
        if not documents:
            return []
        pairs = [[query, doc["text"]] for doc, _ in documents]
        rerank_scores = self.reranker.predict(pairs)
        reranked = [(doc, float(score)) for (doc, _), score in zip(documents, rerank_scores)]
        reranked.sort(key=lambda x: x[1], reverse=True)
        return reranked[:top_k]

    def retrieve(self, query: str, method: str = "hybrid_rerank", top_k: int = 5, expand_query: bool = False, query_expander: Optional['QueryExpander'] = None, **kwargs) -> List[Dict]:
        """Dispatch to a retrieval method ("bm25", "semantic", "hybrid", or
        "hybrid_rerank"), optionally expanding the query first."""
        queries = [query]
        if expand_query and query_expander:
            queries = query_expander.expand_query(query)
            print(f"Expanded to {len(queries)} queries")
        # Sum scores across query variants so chunks that match several
        # phrasings rank higher than single-variant matches.
        all_results = {}
        for q in queries:
            if method == "bm25":
                results = self.retrieve_bm25(q, top_k * 2)
            elif method == "semantic":
                results = self.retrieve_semantic(q, top_k * 2)
            elif method in ("hybrid", "hybrid_rerank"):
                results = self.retrieve_hybrid(q, top_k * 2, kwargs.get("bm25_weight", 0.5), kwargs.get("semantic_weight", 0.5))
            else:
                raise ValueError(f"Unknown method: {method}")
            for doc, score in results:
                doc_id = doc["id"]
                prev_score = all_results.get(doc_id, (doc, 0.0))[1]
                all_results[doc_id] = (doc, prev_score + score)
        aggregated = list(all_results.values())
        aggregated.sort(key=lambda x: x[1], reverse=True)
        if "rerank" in method:
            # Rerank a wider pool (3x top_k) before trimming to top_k.
            print(f"Reranking {len(aggregated)} results...")
            aggregated = self.rerank(query, aggregated[:top_k * 3], top_k)
        else:
            aggregated = aggregated[:top_k]
        return [{"retrieval_score": score, **doc} for doc, score in aggregated]


class RAGSystem:
    """End-to-end RAG: index a knowledge base, retrieve context, generate answers."""

    def __init__(self, openai_client, data_dir: str = "data"):
        self.client = openai_client
        self.retriever = HybridRetriever(data_dir=data_dir)
        self.query_expander = QueryExpander(openai_client)
        self.data_dir = Path(data_dir)

    def load_knowledge_base(self, documents: Dict[str, str], chunk_size: int = 500, overlap: int = 50):
        self.retriever.index_documents(documents, chunk_size, overlap)

    def generate_answer(self, query: str, context: List[Dict], system_prompt: str) -> str:
        """Build an augmented prompt from the retrieved chunks and call the LLM."""
        context_str = "\n\n".join(
            f"[Source: {doc['source']}, Chunk {doc['chunk_idx']}]\n{doc['text']}" for doc in context
        )
        augmented_prompt = f"""{system_prompt}

## Retrieved Context:
{context_str}

## User Query:
{query}

Please answer the query based on the context provided above."""
        messages = [{"role": "user", "content": augmented_prompt}]
        response = self.client.chat.completions.create(model="gpt-4o-mini", messages=messages, temperature=0.7)
        return response.choices[0].message.content

    def query(self, query: str, system_prompt: str, method: str = "hybrid_rerank", top_k: int = 5, expand_query: bool = False, **kwargs) -> Dict:
        """Retrieve context with the chosen method and generate a grounded answer."""
        context = self.retriever.retrieve(
            query,
            method=method,
            top_k=top_k,
            expand_query=expand_query,
            query_expander=self.query_expander if expand_query else None,
            **kwargs,
        )
        answer = self.generate_answer(query, context, system_prompt)
        return {"answer": answer, "context": context, "method": method, "query": query}