import os
import gradio as gr
import openai
from datasets import load_dataset
import logging
import time
from langchain_community.embeddings import HuggingFaceEmbeddings
import torch
import psutil
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Read the OpenAI API key from the environment rather than hard-coding it
openai.api_key = os.getenv("OPENAI_API_KEY")
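# (Assumption: this runs as a Hugging Face Space, where OPENAI_API_KEY would be
# set under Settings -> Variables and secrets instead of committed to the repo.)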
# Initialize the E5 embedding model, placing it on GPU when available
model_name = 'intfloat/e5-base-v2'
device = "cuda" if torch.cuda.is_available() else "cpu"
embedding_model = HuggingFaceEmbeddings(model_name=model_name, model_kwargs={"device": device})
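# A minimal sketch (not called by process_query below, which uses keyword
# matching) of how the loaded E5 model could score a query against a passage.
# E5 models expect "query: " / "passage: " prefixes on their inputs.
def e5_similarity(query, doc_text):
    q = embedding_model.embed_query(f"query: {query}")
    d = embedding_model.embed_query(f"passage: {doc_text}")
    dot = sum(a * b for a, b in zip(q, d))
    norm = (sum(a * a for a in q) ** 0.5) * (sum(b * b for b in d) ** 0.5)
    return dot / norm  # cosine similarity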
# Load the RagBench training splits
datasets = {}
dataset_names = ['covidqa', 'hotpotqa', 'pubmedqa']
for name in dataset_names:
    datasets[name] = load_dataset("rungalileo/ragbench", name, split='train')
    logger.info(f"Loaded {name}")
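# Each RagBench record exposes its source passages in a 'documents' column,
# which is what process_query scans below.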
def get_system_metrics():
    """Return current CPU and memory utilisation as percentages."""
    return {
        'cpu_percent': psutil.cpu_percent(),
        'memory_percent': psutil.virtual_memory().percent
    }
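# Note: psutil.cpu_percent() reports usage since its previous call, so the
# very first reading after startup may be 0.0.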
def process_query(query, dataset_choice="all"):
    start_time = time.time()
    try:
        relevant_contexts = []
        search_datasets = [dataset_choice] if dataset_choice != "all" else datasets.keys()
        for dataset_name in search_datasets:
            if dataset_name in datasets:
                documents = datasets[dataset_name]['documents']
                for doc in documents:
                    # Handle both string and list document types
                    if isinstance(doc, list):
                        doc_text = ' '.join(doc)
                    else:
                        doc_text = str(doc)
                    # Keep any document that shares at least one query term
                    if any(keyword.lower() in doc_text.lower() for keyword in query.split()):
                        relevant_contexts.append((doc_text, dataset_name))
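        # Only the first matching passage is passed to the model; any further
        # matches are collected but unused (a simple top-1 context strategy).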
        context_info = f"From {relevant_contexts[0][1]}: {relevant_contexts[0][0]}" if relevant_contexts else "Searching across datasets..."
        response = openai.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": "You are a knowledgeable expert using E5 embeddings for precise information retrieval."},
                {"role": "user", "content": f"Context: {context_info}\nQuestion: {query}"}
            ],
            max_tokens=300,
            temperature=0.7,
        )
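        # openai.chat.completions.create is the openai>=1.0 interface; the
        # module-level api_key set at import time is picked up automatically.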
        metrics = get_system_metrics()
        metrics['processing_time'] = time.time() - start_time
        metrics_display = f"""
Processing Time: {metrics['processing_time']:.2f}s
CPU Usage: {metrics['cpu_percent']}%
Memory Usage: {metrics['memory_percent']}%
"""
        return response.choices[0].message.content.strip(), metrics_display
    except Exception as e:
        logger.exception("Query processing failed")
        return str(e), "Performance metrics available on next query"
demo = gr.Interface(
    fn=process_query,
    inputs=[
        gr.Textbox(label="Question", placeholder="Ask your question here"),
        gr.Dropdown(
            choices=["all"] + dataset_names,
            label="Select Dataset",
            value="all"
        )
    ],
    outputs=[
        gr.Textbox(label="Response"),
        gr.Textbox(label="Performance Metrics")
    ],
    title="E5-Powered Knowledge Base",
    description="Search across RagBench datasets with performance monitoring"
)
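# queue() serialises concurrent requests so simultaneous users are handled in
# order; debug=True surfaces tracebacks while the app is running.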
if __name__ == "__main__":
    demo.queue()
    demo.launch(debug=True)