import os
import sys
import zipfile

# --- 1. SQLITE FIX ---
try:
    __import__('pysqlite3')
    sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
except ImportError:
    pass
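# Note: the shim above only does anything if the `pysqlite3` package can be
# imported. On a Hugging Face Space that typically means listing
# `pysqlite3-binary` in requirements.txt (an assumption about this Space's
# setup); without it, the stock sqlite3 module is used, which may be too old
# for Chroma.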
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain_huggingface import HuggingFacePipeline, HuggingFaceEmbeddings
from langchain_chroma import Chroma
from typing import Dict, Any
# --- 2. UNZIP & AUTO-DETECT PATH ---
print("Checking for Database...")

# Unzip if the zip exists
if os.path.exists("./chroma_db.zip"):
    print("Found zip file! Unzipping...")
    with zipfile.ZipFile("./chroma_db.zip", 'r') as zip_ref:
        zip_ref.extractall(".")
    print("Unzip complete.")

# SMART DETECTION: Find where the database went
db_path = ""
if os.path.exists("./chroma_db/chroma.sqlite3"):
    # Case A: It's inside the folder (perfect)
    db_path = "./chroma_db"
    print(f"Found database in folder: {db_path}")
elif os.path.exists("./chroma.sqlite3"):
    # Case B: It spilled into the root directory
    db_path = "."
    print(f"Found database in root directory: {db_path}")
elif os.path.exists("./content/chroma_db/chroma.sqlite3"):
    # Case C: It's inside a 'content' folder (common Colab issue)
    db_path = "./content/chroma_db"
    print(f"Found database in content folder: {db_path}")
else:
    # Case D: Panic
    # List the files to help debug
    print("ERROR: Cannot find chroma.sqlite3. Current files in folder:")
    print(os.listdir("."))
    raise ValueError("Could not find the database file after unzipping!")
# --- 3. MODEL SETUP ---
print("Loading Embeddings...")
embedding_function = HuggingFaceEmbeddings(
    model_name="nomic-ai/nomic-embed-text-v1.5",
    model_kwargs={"trust_remote_code": True, "device": "cpu"}
)

print(f"Loading Database from {db_path}...")
vector_db = Chroma(
    persist_directory=db_path,
    embedding_function=embedding_function
)
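# Optional sanity check: confirm the persisted collection actually contains
# documents. A hedged sketch: `_collection` is a private attribute of the
# langchain Chroma wrapper, so this is wrapped in try/except.
try:
    print(f"Database contains {vector_db._collection.count()} chunks.")
except Exception as e:
    print(f"Could not count documents: {e}")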
print("Loading TinyLlama Model...")
model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=256,
    repetition_penalty=1.15,
    temperature=0.1,
    do_sample=True
)
llm = HuggingFacePipeline(pipeline=pipe)
# --- 4. RAG CHAIN ---
class ManualQAChain:
    def __init__(self, vector_store: Chroma, llm_pipeline: HuggingFacePipeline):
        self.retriever = vector_store.as_retriever(search_kwargs={"k": 2})
        self.llm = llm_pipeline

    def invoke(self, inputs: Dict[str, str]) -> Dict[str, Any]:
        query = inputs.get("query", "")

        # Retrieval
        docs = self.retriever.invoke(query)
        context = "\n\n".join([d.page_content for d in docs]) if docs else "No context found."

        # Prompt (TinyLlama chat format)
        prompt = f"""<|system|>
You are a helpful medical assistant. Use ONLY the context below.
If the answer is not in the context, say "I cannot find the answer."
Context:
{context[:2000]}
</s>
<|user|>
{query}
</s>
<|assistant|>
"""

        # Generation (the pipeline echoes the prompt, so split off everything
        # after the final <|assistant|> tag)
        response = self.llm.invoke(prompt)
        text = response[0]['generated_text'] if isinstance(response, list) else str(response)
        if "<|assistant|>" in text:
            final_answer = text.split("<|assistant|>")[-1].strip()
        else:
            final_answer = text.strip()

        return {"result": final_answer, "source_documents": docs}

# Initialize
qa_chain = ManualQAChain(vector_db, llm)
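# Optional smoke test before wiring up the UI (uses one of the example
# questions below; illustrative only, uncomment to run a single query at startup):
# print(qa_chain.invoke({"query": "What are the symptoms of Lung Cancer?"})["result"])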
# --- 5. UI ---
def medical_rag_chat(message, history):
    if not message:
        return "Please ask a question."
    try:
        response = qa_chain.invoke({"query": message})

        sources = "\n\n---\n**Retrieved Context:**\n"
        if response.get('source_documents'):
            for i, doc in enumerate(response['source_documents']):
                topic = doc.metadata.get('focus_area', 'Protocol')
                sources += f"**{i+1}. [{topic}]** {doc.page_content[:300]}...\n"
        else:
            sources += "(No context found)"

        return response['result'] + sources
    except Exception as e:
        return f"Error: {str(e)}"

demo = gr.ChatInterface(
    fn=medical_rag_chat,
    title="Cardio-Oncology RAG Assistant",
    description="TinyLlama-1.1B + MedQuAD RAG",
    examples=["What are the symptoms of Lung Cancer?", "Who is at risk for Heart Failure?"]
)

if __name__ == "__main__":
    demo.launch()
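# On a Hugging Face Space, launch() needs no arguments; when running the app
# locally you can pass share=True to get a temporary public URL.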