import streamlit as st
import tempfile
import os
import json
from typing import List, Dict, Any, Optional, Tuple
import traceback

# Import our modules
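# (assumed project layout: these live in the Space's src/ and config/ packages)
# DocumentProcessor parses and chunks uploads, LLMExtractor queries an OpenRouter-hosted
# model, GraphBuilder assembles the graph, and GraphVisualizer renders it to an image file.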
from src.document_processor import DocumentProcessor
from src.llm_extractor import LLMExtractor
from src.graph_builder import GraphBuilder
from src.visualizer import GraphVisualizer
from config.settings import Config

# Page config
st.set_page_config(
    page_title="Knowledge Graph Extraction",
    page_icon="🕸️",
    layout="wide"
)

# Initialize components
def initialize_components():
    config = Config()
    doc_processor = DocumentProcessor()
    llm_extractor = LLMExtractor()
    graph_builder = GraphBuilder()
    visualizer = GraphVisualizer()
    return config, doc_processor, llm_extractor, graph_builder, visualizer

config, doc_processor, llm_extractor, graph_builder, visualizer = initialize_components()
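# Note: this runs at import time, so Streamlit re-creates every component on each rerun
# of the script; wrapping initialize_components() in @st.cache_resource would reuse them.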


def process_uploaded_files(uploaded_files, api_key, batch_mode, layout_type,
                           show_labels, show_edge_labels, min_importance, entity_types_filter):
    """Process uploaded files and extract a knowledge graph."""
    try:
        # Update API key
        if api_key.strip():
            config.OPENROUTER_API_KEY = api_key.strip()
            llm_extractor.config.OPENROUTER_API_KEY = api_key.strip()
            llm_extractor.headers["Authorization"] = f"Bearer {api_key.strip()}"

        if not config.OPENROUTER_API_KEY:
            st.error("❌ OpenRouter API key is required")
            return None

        if not uploaded_files:
            st.error("❌ Please upload at least one file")
            return None

        progress_bar = st.progress(0)
        status_text = st.empty()

        status_text.text("Loading documents...")
        progress_bar.progress(0.1)
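
        # Streamlit's UploadedFile objects live only in memory; NamedTemporaryFile(delete=False)
        # materializes each one on disk so DocumentProcessor can work with real file paths.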
        # Save uploaded files to temporary location
        file_paths = []
        for uploaded_file in uploaded_files:
            # Create temporary file
            with tempfile.NamedTemporaryFile(delete=False, suffix=f"_{uploaded_file.name}") as tmp_file:
                tmp_file.write(uploaded_file.getvalue())
                file_paths.append(tmp_file.name)

        # Process documents
        doc_results = doc_processor.process_documents(file_paths, batch_mode)
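
        # The temp files were created with delete=False, so they must be removed manually
        # once processing is done; unlink failures are deliberately ignored.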
        # Clean up temporary files
        for file_path in file_paths:
            try:
                os.unlink(file_path)
            except OSError:
                pass

        # Check for errors
        failed_files = [r for r in doc_results if r['status'] == 'error']
        if failed_files:
            error_msg = "Failed to process files:\n" + "\n".join(
                [f"- {r['file_path']}: {r['error']}" for r in failed_files]
            )
            if len(failed_files) == len(doc_results):
                st.error(f"❌ {error_msg}")
                return None

        status_text.text("Extracting entities and relationships...")
        progress_bar.progress(0.3)
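
        # Each successfully parsed document is handed to the LLM chunk by chunk;
        # process_chunks is expected to return 'entities', 'relationships' and 'errors'
        # lists, and per-chunk errors are collected instead of aborting the whole run.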
        # Extract entities and relationships
        all_entities = []
        all_relationships = []
        extraction_errors = []
        for doc_result in doc_results:
            if doc_result['status'] == 'success':
                extraction_result = llm_extractor.process_chunks(doc_result['chunks'])
                if extraction_result.get('errors'):
                    extraction_errors.extend(extraction_result['errors'])
                all_entities.extend(extraction_result.get('entities', []))
                all_relationships.extend(extraction_result.get('relationships', []))

        if not all_entities:
            error_msg = "No entities extracted from documents"
            if extraction_errors:
                error_msg += f"\nExtraction errors: {'; '.join(extraction_errors[:3])}"
            st.error(f"❌ {error_msg}")
            return None

        status_text.text("Building knowledge graph...")
        progress_bar.progress(0.6)

        # Build graph
        graph = graph_builder.build_graph(all_entities, all_relationships)
        if not graph.nodes():
            st.error("❌ No valid knowledge graph could be built")
            return None

        status_text.text("Applying filters...")
        progress_bar.progress(0.7)
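
        # filter_graph is called without passing the graph, so it presumably filters the
        # builder's internally stored graph; an empty entity-type list keeps all types and
        # only the importance threshold applies.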
        # Apply filters
        filtered_graph = graph
        if entity_types_filter:
            filtered_graph = graph_builder.filter_graph(
                entity_types=entity_types_filter,
                min_importance=min_importance
            )
        elif min_importance > 0:
            filtered_graph = graph_builder.filter_graph(min_importance=min_importance)

        if not filtered_graph.nodes():
            st.error("❌ No entities remain after applying filters")
            return None

        status_text.text("Generating visualizations...")
        progress_bar.progress(0.8)
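
        # visualize_graph is expected to render the filtered graph to an image file and
        # return its path; the path is displayed later with st.image().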
        # Generate graph visualization
        graph_image_path = visualizer.visualize_graph(
            filtered_graph,
            layout_type=layout_type,
            show_labels=show_labels,
            show_edge_labels=show_edge_labels
        )

        # Get statistics
        stats = graph_builder.get_graph_statistics()
        stats_summary = visualizer.create_statistics_summary(filtered_graph, stats)

        # Get entity list
        entity_list = visualizer.create_entity_list(filtered_graph)

        # Get central nodes
        central_nodes = graph_builder.get_central_nodes()
        central_nodes_text = "## Most Central Entities\n\n"
        for i, (node, score) in enumerate(central_nodes, 1):
            central_nodes_text += f"{i}. **{node}** (centrality: {score:.3f})\n"

        status_text.text("Complete!")
        progress_bar.progress(1.0)

        # Success message
        success_msg = f"✅ Successfully processed {len([r for r in doc_results if r['status'] == 'success'])} document(s)"
        if failed_files:
            success_msg += f"\n⚠️ {len(failed_files)} file(s) failed to process"
        if extraction_errors:
            success_msg += f"\n⚠️ {len(extraction_errors)} extraction error(s) occurred"

        return {
            'success_msg': success_msg,
            'graph_image_path': graph_image_path,
            'stats_summary': stats_summary,
            'entity_list': entity_list,
            'central_nodes_text': central_nodes_text,
            'graph': filtered_graph
        }

    except Exception as e:
        st.error(f"❌ Error: {str(e)}")
        st.error(f"Full traceback:\n{traceback.format_exc()}")
        return None


# Main app
def main():
    st.title("🕸️ Knowledge Graph Extraction")
    st.markdown("""
    Upload documents and extract knowledge graphs using LLMs via OpenRouter.
    Supports PDF, TXT, DOCX, and JSON files.
    """)

    # Sidebar for configuration
    with st.sidebar:
        st.header("📄 Document Upload")
        uploaded_files = st.file_uploader(
            "Choose files",
            type=['pdf', 'txt', 'docx', 'json'],
            accept_multiple_files=True
        )
        batch_mode = st.checkbox(
            "Batch Processing Mode",
            value=False,
            help="Process multiple files together"
        )

        st.header("🔑 API Configuration")
        api_key = st.text_input(
            "OpenRouter API Key",
            type="password",
            placeholder="Enter your OpenRouter API key",
            help="Get your key at openrouter.ai"
        )

        st.header("🎛️ Visualization Settings")
        layout_type = st.selectbox(
            "Layout Algorithm",
            options=visualizer.get_layout_options(),
            index=0
        )
        show_labels = st.checkbox("Show Node Labels", value=True)
        show_edge_labels = st.checkbox("Show Edge Labels", value=False)

        st.header("🔍 Filtering Options")
        min_importance = st.slider(
            "Minimum Entity Importance",
            min_value=0.0,
            max_value=1.0,
            value=0.3,
            step=0.1
        )
        entity_types_filter = st.multiselect(
            "Entity Types Filter",
            options=[],
            help="Filter will be populated after processing"
        )
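
        # Note: with options=[], this multiselect has nothing to offer on the first run;
        # populating it from a previously processed graph (e.g., kept in st.session_state)
        # would make the entity-type filter usable on subsequent runs.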

        process_button = st.button("🚀 Extract Knowledge Graph", type="primary")

    # Main content area
    if process_button and uploaded_files:
        with st.spinner("Processing..."):
            result = process_uploaded_files(
                uploaded_files, api_key, batch_mode, layout_type,
                show_labels, show_edge_labels, min_importance, entity_types_filter
            )

        if result:
            # Store results in session state
            st.session_state['result'] = result
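            # Kept in session_state so the data could survive later reruns, though nothing
            # in this script reads it back yet.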

            # Display success message
            st.success(result['success_msg'])

            # Create tabs for results
            tab1, tab2, tab3, tab4 = st.tabs(
                ["📊 Graph Visualization", "📈 Statistics", "📋 Entities", "🎯 Central Nodes"]
            )

            with tab1:
                if result['graph_image_path'] and os.path.exists(result['graph_image_path']):
                    st.image(result['graph_image_path'], caption="Knowledge Graph", use_column_width=True)
                else:
                    st.error("Failed to generate graph visualization")

            with tab2:
                st.markdown(result['stats_summary'])

            with tab3:
                st.markdown(result['entity_list'])

            with tab4:
                st.markdown(result['central_nodes_text'])

            # Export options
            st.header("💾 Export Options")
            col1, col2 = st.columns(2)

            with col1:
                export_format = st.selectbox(
                    "Export Format",
                    options=["json", "graphml", "gexf"],
                    index=0
                )
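
            # GraphML and GEXF are XML-based graph formats readable by tools such as Gephi;
            # export_graph presumably serializes GraphBuilder's current graph in the chosen format.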
            with col2:
                if st.button("📥 Export Graph"):
                    try:
                        export_data = graph_builder.export_graph(export_format)
                        st.text_area("Export Data", value=export_data, height=300)

                        # Download button
                        st.download_button(
                            label=f"Download {export_format.upper()} file",
                            data=export_data,
                            file_name=f"knowledge_graph.{export_format}",
                            mime="application/octet-stream"
                        )
                    except Exception as e:
                        st.error(f"Export failed: {str(e)}")

    elif process_button and not uploaded_files:
        st.warning("Please upload at least one file before processing.")

    # Instructions
    st.header("📖 Instructions")
    with st.expander("How to use this app"):
        st.markdown("""
        1. **Upload Documents**: Select one or more files (PDF, TXT, DOCX, JSON) using the file uploader in the sidebar
        2. **Enter API Key**: Get a free API key from [OpenRouter](https://openrouter.ai) and enter it in the sidebar
        3. **Configure Settings**: Adjust visualization and filtering options in the sidebar
        4. **Extract Graph**: Click the "Extract Knowledge Graph" button and wait for processing
        5. **Explore Results**: View the graph, statistics, and entity details in the tabs
        6. **Export**: Download the graph data in various formats
        """)

    with st.expander("Features"):
        st.markdown("""
        - **Multi-format Support**: PDF, TXT, DOCX, JSON files
        - **Batch Processing**: Process multiple documents together
        - **Smart Extraction**: Uses an LLM to identify important entities and relationships
        - **Interactive Filtering**: Filter by entity type and importance
        - **Multiple Layouts**: Various graph layout algorithms
        - **Export Options**: JSON, GraphML, GEXF formats
        - **Free Models**: Uses cost-effective OpenRouter models
        """)

    with st.expander("Notes"):
        st.markdown("""
        - File size limit: 10MB per file
        - Free OpenRouter models are used to minimize costs
        - Processing time depends on document size and complexity
        """)


if __name__ == "__main__":
    main()