import os
import sys
import zipfile
# --- 1. SQLITE FIX ---
try:
    __import__('pysqlite3')
    sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
except ImportError:
    pass
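# Why this swap: Chroma requires sqlite3 >= 3.35, newer than the build bundled
# with many hosts (e.g. stock Hugging Face Spaces images). The swap only works
# if pysqlite3-binary is listed in requirements.txt (assumed here); otherwise
# the ImportError is swallowed and the system sqlite3 is used as-is.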
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain_huggingface import HuggingFacePipeline, HuggingFaceEmbeddings
from langchain_chroma import Chroma
from typing import Dict, Any, List
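# Assumed requirements.txt (not part of this file): torch, transformers, gradio,
# langchain-huggingface, langchain-chroma, pysqlite3-binary, and einops (needed
# by the nomic embedding model loaded below).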
# --- 2. UNZIP & AUTO-DETECT PATH ---
print("⏳ Checking for Database...")
# Unzip if the zip exists
if os.path.exists("./chroma_db.zip"):
    print("📦 Found zip file! Unzipping...")
    with zipfile.ZipFile("./chroma_db.zip", 'r') as zip_ref:
        zip_ref.extractall(".")
    print("✅ Unzip complete.")
# SMART DETECTION: Find where the database went
db_path = ""
if os.path.exists("./chroma_db/chroma.sqlite3"):
# Case A: It's inside the folder (Perfect)
db_path = "./chroma_db"
print(f"πŸ“‚ Found database in folder: {db_path}")
elif os.path.exists("./chroma.sqlite3"):
# Case B: It spilled into the root directory
db_path = "."
print(f"πŸ“‚ Found database in root directory: {db_path}")
elif os.path.exists("./content/chroma_db/chroma.sqlite3"):
# Case C: It's inside a 'content' folder (Common Colab issue)
db_path = "./content/chroma_db"
print(f"πŸ“‚ Found database in content folder: {db_path}")
else:
# Case D: Panic
# Let's list the files to debug
print("❌ ERROR: Cannot find chroma.sqlite3. Current files in folder:")
print(os.listdir("."))
raise ValueError("Could not find the database file after unzipping!")
# --- 3. MODEL SETUP ---
print("⏳ Loading Embeddings...")
embedding_function = HuggingFaceEmbeddings(
    model_name="nomic-ai/nomic-embed-text-v1.5",
    model_kwargs={"trust_remote_code": True, "device": "cpu"}
)
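# Note: queries only match if this is the same embedding model the database was
# built with. nomic-embed-text-v1.5 ships custom code, hence trust_remote_code=True.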
print(f"⏳ Loading Database from {db_path}...")
vector_db = Chroma(
    persist_directory=db_path,
    embedding_function=embedding_function
)
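# Optional sanity check (a minimal sketch; assumes vector_db.get() returns a
# dict with an "ids" list, per the langchain_chroma API): fail fast if the
# collection is empty, which usually means the wrong persist_directory was used.
if not vector_db.get(limit=1)["ids"]:
    raise ValueError(f"Chroma collection at {db_path} is empty!")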
print("⏳ Loading TinyLlama Model...")
model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
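# Assumption: CPU-only host, so the model loads in full fp32 (~4.4 GB for 1.1B
# params). Passing torch_dtype=torch.bfloat16 to from_pretrained would roughly
# halve memory if the host supports it.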
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=256,
    repetition_penalty=1.15,
    temperature=0.1,
    do_sample=True
)
llm = HuggingFacePipeline(pipeline=pipe)
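# Note: the text-generation pipeline defaults to return_full_text=True, so
# llm.invoke() returns the prompt followed by the completion; ManualQAChain
# below splits on the final <|assistant|> tag to keep only the answer.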
# --- 4. RAG CHAIN ---
class ManualQAChain:
    def __init__(self, vector_store: Chroma, llm_pipeline: HuggingFacePipeline):
        self.retriever = vector_store.as_retriever(search_kwargs={"k": 2})
        self.llm = llm_pipeline

    def invoke(self, inputs: Dict[str, str]) -> Dict[str, Any]:
        query = inputs.get("query", "")
        # Retrieval: fetch the top-2 chunks for the query
        docs = self.retriever.invoke(query)
        context = "\n\n".join([d.page_content for d in docs]) if docs else "No context found."
        # Prompt: Zephyr-style chat tags, the format TinyLlama-Chat was tuned on
        prompt = f"""<|system|>
You are a helpful medical assistant. Use ONLY the context below.
If the answer is not in the context, say "I cannot find the answer."
Context:
{context[:2000]}
</s>
<|user|>
{query}
</s>
<|assistant|>
"""
        # Generation (the pipeline echoes the prompt, so strip it off below)
        response = self.llm.invoke(prompt)
        text = response[0]['generated_text'] if isinstance(response, list) else str(response)
        if "<|assistant|>" in text:
            final_answer = text.split("<|assistant|>")[-1].strip()
        else:
            final_answer = text.strip()
        return {"result": final_answer, "source_documents": docs}
# Initialize
qa_chain = ManualQAChain(vector_db, llm)
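# Illustrative smoke test (uncomment to run one query at startup):
# print(qa_chain.invoke({"query": "What are the symptoms of Lung Cancer?"})["result"])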
# --- 5. UI ---
def medical_rag_chat(message, history):
    if not message:
        return "Please ask a question."
    try:
        response = qa_chain.invoke({"query": message})
        sources = "\n\n---\n**Retrieved Context:**\n"
        if response.get('source_documents'):
            for i, doc in enumerate(response['source_documents']):
                topic = doc.metadata.get('focus_area', 'Protocol')
                sources += f"**{i+1}. [{topic}]** {doc.page_content[:300]}...\n"
        else:
            sources += "(No context found)"
        return response['result'] + sources
    except Exception as e:
        return f"Error: {str(e)}"
demo = gr.ChatInterface(
    fn=medical_rag_chat,
    title="Cardio-Oncology RAG Assistant",
    description="TinyLlama-1.1B + MedQuAD RAG",
    examples=["What are the symptoms of Lung Cancer?", "Who is at risk for Heart Failure?"]
)
if __name__ == "__main__":
    demo.launch()