Spaces:
Sleeping
Sleeping
import base64
import io
import json
import os
import tempfile
from collections import Counter
from typing import Any, Callable, Dict, List, Optional, Tuple

import matplotlib
matplotlib.use('Agg')  # Use non-interactive backend to avoid GUI issues
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from pyvis.network import Network
class GraphVisualizer:
    """Render a NetworkX knowledge graph in several output formats.

    Outputs: static matplotlib PNG files, a standalone vis.js HTML string,
    a Plotly figure, a pyvis HTML file, and Markdown summaries.

    Node attributes read (all optional): ``type``, ``importance``, ``size``,
    ``description``. Edge attributes read (all optional): ``relationship``,
    ``description``. Missing attributes fall back to the defaults used in
    each method.
    """

    # Fallback fill color for entity types not present in ``color_map``.
    DEFAULT_COLOR = '#95A5A6'

    # Layout names understood by ``_calculate_layout``; any other name
    # falls back to the spring layout.
    LAYOUT_OPTIONS = ("spring", "circular", "shell", "kamada_kawai", "random")

    def __init__(self):
        # Entity type -> hex fill color, shared by every renderer below.
        self.color_map = {
            'PERSON': '#FF6B6B',
            'ORGANIZATION': '#4ECDC4',
            'LOCATION': '#45B7D1',
            'CONCEPT': '#96CEB4',
            'EVENT': '#FFEAA7',
            'OBJECT': '#DDA0DD',
            'UNKNOWN': '#95A5A6',
        }

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    def _node_color(self, node_type: Any) -> str:
        """Return the display color for ``node_type`` (gray when unmapped)."""
        return self.color_map.get(node_type, self.DEFAULT_COLOR)

    @staticmethod
    def _node_tooltip(graph: nx.DiGraph, node: Any, data: Dict[str, Any]) -> str:
        """Build the HTML hover text shared by the Plotly and pyvis views."""
        return (
            f"<b>{node}</b><br>"
            f"Type: {data.get('type', 'UNKNOWN')}<br>"
            f"Importance: {data.get('importance', 0.0):.2f}<br>"
            f"Connections: {graph.degree(node)}<br>"
            f"Description: {data.get('description', 'N/A')}"
        )

    @staticmethod
    def _write_to_temp_file(suffix: str, write: Callable[[str], None]) -> str:
        """Create a persistent temp file, fill it via ``write(path)``, return the path.

        The OS handle from ``mkstemp`` is closed before ``write`` runs so no
        descriptor is leaked and the writer can open the path itself. If
        ``write`` fails, the empty/partial file is removed and the error is
        re-raised. On success the caller owns the file and should delete it.
        """
        fd, path = tempfile.mkstemp(suffix=suffix)
        os.close(fd)
        try:
            write(path)
        except BaseException:
            try:
                os.remove(path)
            except OSError:
                pass  # Best-effort cleanup; the original error matters more.
            raise
        return path

    def _save_png(self, fig) -> str:
        """Save a matplotlib figure to a new temp PNG (150 dpi, tight bbox)."""
        return self._write_to_temp_file(
            '.png',
            lambda path: fig.savefig(path, format='png', dpi=150, bbox_inches='tight'),
        )

    @staticmethod
    def _new_pyvis_network() -> Network:
        """Create the pyvis canvas used by both pyvis renderers."""
        return Network(height="600px", width="100%", bgcolor="#ffffff", font_color="black")

    @staticmethod
    def _script_safe_json(value: Any) -> str:
        """Serialize ``value`` as JSON that can be inlined inside a <script> tag.

        ``<``, ``>`` and ``&`` are emitted as Unicode escape sequences, which
        are still valid JSON/JavaScript, so a string such as "</script>" in
        node data cannot terminate the surrounding script element.
        """
        return (json.dumps(value)
                .replace('<', '\\u003c')
                .replace('>', '\\u003e')
                .replace('&', '\\u0026'))

    @staticmethod
    def _sorted_counts(counts: Counter) -> List[Tuple[Any, int]]:
        """Return ``counts`` items ordered by their label's string form.

        Sorting on ``str(label)`` keeps the usual alphabetical order for
        string labels while tolerating mixed label types (e.g. ``None``).
        """
        return sorted(counts.items(), key=lambda item: str(item[0]))

    # ------------------------------------------------------------------
    # Static matplotlib rendering
    # ------------------------------------------------------------------

    def visualize_graph(self,
                        graph: nx.DiGraph,
                        layout_type: str = "spring",
                        show_labels: bool = True,
                        show_edge_labels: bool = False,
                        node_size_factor: float = 1.0,
                        figsize: Tuple[int, int] = (12, 8)) -> str:
        """Draw ``graph`` with matplotlib and save it as a PNG file.

        Args:
            graph: Graph to draw.
            layout_type: One of ``LAYOUT_OPTIONS``; unknown names use spring.
            show_labels: Label each node with its name and ``importance``.
            show_edge_labels: Label each edge with its ``relationship``.
            node_size_factor: Multiplier applied to ``size * 10`` per node.
            figsize: Figure size in inches.

        Returns:
            Path of a new temporary PNG file owned by the caller. An empty
            graph yields a placeholder image instead.
        """
        if not graph.nodes():
            return self._create_empty_graph_image()

        fig, ax = plt.subplots(figsize=figsize)
        try:
            pos = self._calculate_layout(graph, layout_type)

            # Both lists follow graph.nodes() order, matching draw_networkx_nodes.
            node_colors = [self._node_color(data.get('type', 'UNKNOWN'))
                           for _, data in graph.nodes(data=True)]
            node_sizes = [data.get('size', 20) * node_size_factor * 10
                          for _, data in graph.nodes(data=True)]

            nx.draw_networkx_nodes(graph, pos, ax=ax,
                                   node_color=node_colors,
                                   node_size=node_sizes,
                                   alpha=0.8)
            nx.draw_networkx_edges(graph, pos, ax=ax,
                                   edge_color='gray',
                                   arrows=True,
                                   arrowsize=20,
                                   alpha=0.6,
                                   width=1.5)

            if show_labels:
                # Node name plus its importance score on a second line.
                labels = {node: f"{node}\n({data.get('importance', 0.0):.2f})"
                          for node, data in graph.nodes(data=True)}
                nx.draw_networkx_labels(graph, pos, labels=labels, font_size=8, ax=ax)

            if show_edge_labels:
                edge_labels = {(u, v): data.get('relationship', '')
                               for u, v, data in graph.edges(data=True)}
                nx.draw_networkx_edge_labels(graph, pos, edge_labels=edge_labels,
                                             font_size=6, ax=ax)

            ax.set_title("Knowledge Graph", fontsize=16, fontweight='bold')
            ax.set_axis_off()
            fig.tight_layout()
            return self._save_png(fig)
        finally:
            # Always release the figure, even if layout/drawing/saving raised.
            plt.close(fig)

    def _calculate_layout(self, graph: nx.DiGraph, layout_type: str) -> Dict[Any, np.ndarray]:
        """Compute node positions with the named layout algorithm.

        Unknown ``layout_type`` values use the spring layout. If the chosen
        algorithm raises (``kamada_kawai_layout``, for instance, depends on
        SciPy), the spring layout is tried as a fallback.

        Returns:
            Mapping of node -> 2-D position array.
        """
        def spring(g):
            return nx.spring_layout(g, k=1, iterations=50)

        layouts = {
            "spring": spring,
            "circular": nx.circular_layout,
            "shell": nx.shell_layout,
            "kamada_kawai": nx.kamada_kawai_layout,
            "random": nx.random_layout,
        }
        try:
            return layouts.get(layout_type, spring)(graph)
        except Exception:
            # Deliberate best-effort fallback; narrower than a bare except so
            # KeyboardInterrupt/SystemExit still propagate.
            return spring(graph)

    def _create_empty_graph_image(self) -> str:
        """Render a placeholder PNG for an empty graph and return its path."""
        fig, ax = plt.subplots(figsize=(8, 6))
        try:
            ax.text(0.5, 0.5, 'No graph data to display',
                    horizontalalignment='center', verticalalignment='center',
                    fontsize=16, transform=ax.transAxes)
            ax.set_axis_off()
            return self._save_png(fig)
        finally:
            plt.close(fig)

    # ------------------------------------------------------------------
    # Interactive renderers
    # ------------------------------------------------------------------

    def create_interactive_html(self, graph: nx.DiGraph) -> str:
        """Return a standalone HTML page that renders ``graph`` with vis.js.

        Node and edge data are inlined as script-safe JSON. The vis-network
        script is loaded from unpkg at view time.
        """
        if not graph.nodes():
            return "<div>No graph data to display</div>"

        nodes = [
            {
                "id": node,
                "label": str(node),
                "color": self._node_color(data.get('type', 'UNKNOWN')),
                "size": data.get('size', 20),
                # NOTE(review): recent vis-network releases may show string
                # titles as plain text, so the <br> tags could appear
                # literally — confirm against the loaded version.
                "title": (f"Type: {data.get('type', 'UNKNOWN')}<br>"
                          f"Importance: {data.get('importance', 0.0):.2f}<br>"
                          f"Description: {data.get('description', 'N/A')}"),
            }
            for node, data in graph.nodes(data=True)
        ]
        edges = [
            {
                "from": u,
                "to": v,
                "label": data.get('relationship', ''),
                "title": data.get('description', ''),
                "arrows": {"to": {"enabled": True}},
            }
            for u, v, data in graph.edges(data=True)
        ]

        nodes_json = self._script_safe_json(nodes)
        edges_json = self._script_safe_json(edges)

        return f"""
        <!DOCTYPE html>
        <html>
        <head>
            <script src="https://unpkg.com/vis-network/standalone/umd/vis-network.min.js"></script>
            <style>
                #mynetworkid {{
                    width: 100%;
                    height: 600px;
                    border: 1px solid lightgray;
                }}
            </style>
        </head>
        <body>
            <div id="mynetworkid"></div>
            <script>
                var nodes = new vis.DataSet({nodes_json});
                var edges = new vis.DataSet({edges_json});
                var container = document.getElementById('mynetworkid');
                var data = {{
                    nodes: nodes,
                    edges: edges
                }};
                var options = {{
                    nodes: {{
                        shape: 'dot',
                        scaling: {{
                            min: 10,
                            max: 30
                        }},
                        font: {{
                            size: 12,
                            face: 'Tahoma'
                        }}
                    }},
                    edges: {{
                        font: {{align: 'middle'}},
                        color: {{color:'gray'}},
                        arrows: {{to: {{enabled: true, scaleFactor: 1}}}}
                    }},
                    physics: {{
                        enabled: true,
                        stabilization: {{enabled: true, iterations: 200}}
                    }},
                    interaction: {{
                        hover: true,
                        tooltipDelay: 200
                    }}
                }};
                var network = new vis.Network(container, data, options);
            </script>
        </body>
        </html>
        """

    def create_plotly_interactive(self, graph: nx.DiGraph, layout_type: str = "spring") -> go.Figure:
        """Build an interactive Plotly figure of ``graph``.

        Nodes are markers colored by ``type`` and sized by ``size`` (at least
        10), with hover text from the shared tooltip. Edges are one gray line
        trace with hover disabled; edge direction is not drawn.

        Args:
            graph: Graph to draw.
            layout_type: One of ``LAYOUT_OPTIONS``; unknown names use spring.

        Returns:
            A ``plotly.graph_objects.Figure``. An empty graph yields a figure
            containing only a "No graph data to display" annotation.
        """
        if not graph.nodes():
            fig = go.Figure()
            fig.add_annotation(
                text="No graph data to display",
                xref="paper", yref="paper",
                x=0.5, y=0.5, xanchor='center', yanchor='middle',
                showarrow=False, font=dict(size=16)
            )
            return fig

        pos = self._calculate_layout(graph, layout_type)

        node_x, node_y, node_text, node_info, node_colors, node_sizes = [], [], [], [], [], []
        for node, data in graph.nodes(data=True):
            x, y = pos[node]
            node_x.append(x)
            node_y.append(y)
            node_text.append(node)
            node_info.append(self._node_tooltip(graph, node, data))
            node_colors.append(self._node_color(data.get('type', 'UNKNOWN')))
            node_sizes.append(max(10, data.get('size', 20)))

        # One polyline for all edges; ``None`` breaks the line between segments.
        edge_x, edge_y = [], []
        for u, v in graph.edges():
            x0, y0 = pos[u]
            x1, y1 = pos[v]
            edge_x.extend([x0, x1, None])
            edge_y.extend([y0, y1, None])

        edge_trace = go.Scatter(
            x=edge_x, y=edge_y,
            line=dict(width=2, color='gray'),
            hoverinfo='none',
            mode='lines'
        )
        node_trace = go.Scatter(
            x=node_x, y=node_y,
            mode='markers+text',
            hoverinfo='text',
            text=node_text,
            hovertext=node_info,
            textposition="middle center",
            marker=dict(
                size=node_sizes,
                color=node_colors,
                line=dict(width=2, color='white')
            )
        )

        layout = go.Layout(
            # ``title.font`` replaces the removed ``titlefont`` attribute,
            # which newer Plotly versions reject.
            title=dict(text='Interactive Knowledge Graph', font=dict(size=16)),
            showlegend=False,
            hovermode='closest',
            margin=dict(b=20, l=5, r=5, t=40),
            annotations=[dict(
                text="Hover over nodes for details. Drag to pan, scroll to zoom.",
                showarrow=False,
                xref="paper", yref="paper",
                x=0.005, y=-0.002,
                xanchor='left', yanchor='bottom',
                font=dict(color="gray", size=12)
            )],
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            plot_bgcolor='white'
        )
        return go.Figure(data=[edge_trace, node_trace], layout=layout)

    def create_pyvis_interactive(self, graph: nx.DiGraph, layout_type: str = "spring") -> str:
        """Render ``graph`` with pyvis and return the path of an HTML file.

        Positions come from the vis.js physics engine configured below, so
        ``layout_type`` is accepted for API parity but is not used.

        Returns:
            Path of a new temporary HTML file owned by the caller. An empty
            graph yields a single placeholder node instead.
        """
        if not graph.nodes():
            return self._create_empty_pyvis_graph()

        net = self._new_pyvis_network()
        net.set_options("""
        {
            "physics": {
                "enabled": true,
                "stabilization": {"enabled": true, "iterations": 200},
                "barnesHut": {
                    "gravitationalConstant": -2000,
                    "centralGravity": 0.3,
                    "springLength": 95,
                    "springConstant": 0.04,
                    "damping": 0.09
                }
            },
            "interaction": {
                "hover": true,
                "tooltipDelay": 200,
                "hideEdgesOnDrag": false
            }
        }
        """)

        for node, data in graph.nodes(data=True):
            net.add_node(
                node,
                label=str(node),  # vis.js expects a string label
                title=self._node_tooltip(graph, node, data),
                color=self._node_color(data.get('type', 'UNKNOWN')),
                size=max(10, data.get('size', 20)),
            )

        for u, v, data in graph.edges(data=True):
            relationship = data.get('relationship', 'connected')
            net.add_edge(u, v,
                         title=f"{u} → {v}<br>Relationship: {relationship}",
                         arrows="to", color="gray")

        return self._write_to_temp_file('.html', net.save_graph)

    def _create_empty_pyvis_graph(self) -> str:
        """Write a one-node "No graph data" pyvis page and return its path."""
        net = self._new_pyvis_network()
        net.add_node(1, label="No graph data", color="#cccccc")
        return self._write_to_temp_file('.html', net.save_graph)

    # ------------------------------------------------------------------
    # Text summaries
    # ------------------------------------------------------------------

    def create_statistics_summary(self, graph: nx.DiGraph, stats: Dict[str, Any]) -> str:
        """Format graph statistics and type distributions as Markdown.

        Args:
            graph: Graph whose node ``type`` and edge ``relationship``
                attributes are counted.
            stats: Precomputed metrics; must contain ``num_nodes``,
                ``num_edges``, ``density``, ``is_connected``,
                ``num_components`` and ``avg_degree`` (KeyError otherwise).

        Returns:
            Markdown text, or a short message when the graph is empty.
        """
        if not graph.nodes():
            return "No graph statistics available."

        type_counts = Counter(data.get('type', 'UNKNOWN')
                              for _, data in graph.nodes(data=True))
        rel_counts = Counter(data.get('relationship', 'unknown')
                             for _, _, data in graph.edges(data=True))

        lines = [
            "",
            "## Graph Statistics",
            "",
            "**Basic Metrics:**",
            f"- Nodes: {stats['num_nodes']}",
            f"- Edges: {stats['num_edges']}",
            f"- Density: {stats['density']:.3f}",
            f"- Connected: {'Yes' if stats['is_connected'] else 'No'}",
            f"- Components: {stats['num_components']}",
            f"- Average Degree: {stats['avg_degree']:.2f}",
            "",
            "**Entity Types:**",
        ]
        lines.extend(f"- {entity_type}: {count}"
                     for entity_type, count in self._sorted_counts(type_counts))
        lines.extend(["", "**Relationship Types:**"])
        lines.extend(f"- {rel_type}: {count}"
                     for rel_type, count in self._sorted_counts(rel_counts))
        return "\n".join(lines)

    def create_entity_list(self, graph: nx.DiGraph, sort_by: str = "importance") -> str:
        """Format every node as a Markdown entry.

        Args:
            graph: Graph whose nodes are listed.
            sort_by: ``"importance"`` or ``"connections"`` (both descending),
                or ``"name"`` (ascending). Any other value keeps the graph's
                node order.

        Returns:
            Markdown text, or "No entities found." for an empty graph.
        """
        if not graph.nodes():
            return "No entities found."

        entities = [
            {
                'name': node,
                'type': data.get('type', 'UNKNOWN'),
                'importance': data.get('importance', 0.0),
                'description': data.get('description', 'N/A'),
                'connections': graph.degree(node),
            }
            for node, data in graph.nodes(data=True)
        ]

        # sort_by -> (entity field, descending?)
        sort_specs = {
            "importance": ('importance', True),
            "connections": ('connections', True),
            "name": ('name', False),
        }
        spec = sort_specs.get(sort_by)
        if spec is not None:
            field, descending = spec
            entities.sort(key=lambda entity: entity[field], reverse=descending)

        parts = ["## Entities\n"]
        for entity in entities:
            parts.append(
                f"\n**{entity['name']}** ({entity['type']})\n"
                f"- Importance: {entity['importance']:.2f}\n"
                f"- Connections: {entity['connections']}\n"
                f"- Description: {entity['description']}\n"
            )
        return "".join(parts)

    # ------------------------------------------------------------------
    # Option / vocabulary queries
    # ------------------------------------------------------------------

    def get_layout_options(self) -> List[str]:
        """Return the layout names accepted by the rendering methods."""
        return list(self.LAYOUT_OPTIONS)

    def get_entity_types(self, graph: nx.DiGraph) -> List[str]:
        """Return the sorted distinct node ``type`` values (default "UNKNOWN")."""
        return sorted({data.get('type', 'UNKNOWN') for _, data in graph.nodes(data=True)})

    def get_visualization_options(self) -> List[str]:
        """Return the names of the supported visualization back ends."""
        return ["matplotlib", "plotly", "pyvis", "vis.js"]

    def get_relationship_types(self, graph: nx.DiGraph) -> List[str]:
        """Return the sorted distinct edge ``relationship`` values (default "unknown")."""
        return sorted({data.get('relationship', 'unknown')
                       for _, _, data in graph.edges(data=True)})