Spaces:
Sleeping
Sleeping
| import json | |
| import requests | |
| from typing import List, Dict, Any, Optional | |
| from config.settings import Config | |
class LLMExtractor:
    """Extract entities and relationships from text using an OpenRouter-hosted LLM.

    Results are plain dicts with "entities" and "relationships" lists; all
    public methods return an "error"/"errors" key instead of raising, so
    callers can batch-process documents without try/except at every call.
    """

    def __init__(self):
        self.config = Config()
        # Bearer-token auth headers reused for every OpenRouter request.
        self.headers = {
            "Authorization": f"Bearer {self.config.OPENROUTER_API_KEY}",
            "Content-Type": "application/json",
        }

    def extract_entities_and_relationships(self, text: str) -> Dict[str, Any]:
        """Extract entities and relationships from *text* using the LLM.

        Tries the primary extraction model first and falls back to the
        configured backup model. On total failure returns an empty result
        carrying an "error" key rather than raising.
        """
        prompt = self._create_extraction_prompt(text)
        try:
            response = self._call_openrouter_api(prompt, self.config.EXTRACTION_MODEL)
            return self._parse_extraction_response(response)
        except Exception as primary_err:
            # Primary model failed (network error, bad status, malformed
            # payload): retry once with the backup model before giving up.
            try:
                response = self._call_openrouter_api(prompt, self.config.BACKUP_MODEL)
                return self._parse_extraction_response(response)
            except Exception as backup_err:
                return {
                    "entities": [],
                    "relationships": [],
                    "error": f"Primary: {str(primary_err)}, Backup: {str(backup_err)}",
                }

    def _create_extraction_prompt(self, text: str) -> str:
        """Create prompt for entity and relationship extraction."""
        return f"""
You are an expert knowledge graph extraction system. Analyze the following text and extract:
1. ENTITIES: Important people, organizations, locations, concepts, events, objects, etc.
2. RELATIONSHIPS: How these entities relate to each other
3. IMPORTANCE SCORES: Rate each entity's importance from 0.0 to 1.0 based on how central it is to the text
For each entity, provide:
- name: The entity name (standardized/canonical form)
- type: The entity type (PERSON, ORGANIZATION, LOCATION, CONCEPT, EVENT, OBJECT, etc.)
- importance: Score from 0.0 to 1.0
- description: Brief description of the entity's role/significance
For each relationship, provide:
- source: Source entity name
- target: Target entity name
- relationship: Type of relationship (works_at, located_in, part_of, causes, etc.)
- description: Brief description of the relationship
Only respond with a valid JSON object with this structure and nothing else. Your response must be valid, parsable JSON!!
=== JSON STRUCTURE FOR RESPONSE / RESPONSE FORMAT ===
{{
"entities": [
{{
"name": "entity_name",
"type": "ENTITY_TYPE",
"importance": 0.8,
"description": "Brief description"
}}
],
"relationships": [
{{
"source": "entity1",
"target": "entity2",
"relationship": "relationship_type",
"description": "Brief description"
}}
]
}}
=== END OF JSON STRUCTURE FOR RESPONSE / END OF RESPONSE FORMAT ===
TEXT TO ANALYZE:
{text}
Reply in valid json using the format above!
JSON OUTPUT:
"""

    def _call_openrouter_api(self, prompt: str, model: str) -> str:
        """POST a single-message chat completion to OpenRouter; return the reply text.

        Raises:
            ValueError: if no API key is configured.
            RuntimeError: on a non-200 response or a body without choices.
        """
        if not self.config.OPENROUTER_API_KEY:
            raise ValueError("OpenRouter API key not configured")
        payload = {
            "model": model,
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": 2048,
            # Near-zero temperature: we need deterministic, parseable JSON.
            "temperature": 0.1,
        }
        response = requests.post(
            f"{self.config.OPENROUTER_BASE_URL}/chat/completions",
            headers=self.headers,
            json=payload,
            timeout=60,
        )
        if response.status_code != 200:
            raise RuntimeError(f"API call failed: {response.status_code} - {response.text}")
        result = response.json()
        if "choices" not in result or not result["choices"]:
            raise RuntimeError("Invalid API response format")
        return result["choices"][0]["message"]["content"]

    def _parse_extraction_response(self, response: str) -> Dict[str, Any]:
        """Parse the LLM response into structured data.

        Extracts the outermost {...} span (models often wrap the JSON in
        prose), validates the shape, filters entities below the importance
        threshold, and truncates to the configured limits. Parsing failures
        are reported via an "error" key, never raised.
        """
        try:
            # Locate the outermost JSON object within the raw reply.
            start_idx = response.find("{")
            end_idx = response.rfind("}") + 1
            if start_idx == -1 or end_idx == 0:
                raise ValueError("No JSON found in response")
            data = json.loads(response[start_idx:end_idx])
            # Guarantee both top-level keys exist.
            data.setdefault("entities", [])
            data.setdefault("relationships", [])
            # Drop low-importance entities; missing scores count as 0.
            filtered_entities = [
                entity for entity in data["entities"]
                if entity.get("importance", 0) >= self.config.ENTITY_IMPORTANCE_THRESHOLD
            ]
            # Enforce configured size limits.
            data["entities"] = filtered_entities[:self.config.MAX_ENTITIES]
            data["relationships"] = data["relationships"][:self.config.MAX_RELATIONSHIPS]
            return data
        except json.JSONDecodeError as e:
            return {
                "entities": [],
                "relationships": [],
                "error": f"JSON parsing error: {str(e)}",
            }
        except Exception as e:
            return {
                "entities": [],
                "relationships": [],
                "error": f"Response parsing error: {str(e)}",
            }

    def process_chunks(self, chunks: List[str]) -> Dict[str, Any]:
        """Process multiple text chunks and combine results.

        Per-chunk failures are collected into "errors" (None when clean)
        instead of aborting the batch.
        """
        all_entities: List[Dict[str, Any]] = []
        all_relationships: List[Dict[str, Any]] = []
        errors: List[str] = []
        for i, chunk in enumerate(chunks):
            try:
                result = self.extract_entities_and_relationships(chunk)
                if "error" in result:
                    errors.append(f"Chunk {i+1}: {result['error']}")
                    continue
                all_entities.extend(result.get("entities", []))
                all_relationships.extend(result.get("relationships", []))
            except Exception as e:
                errors.append(f"Chunk {i+1}: {str(e)}")
        # Deduplicate entities, then keep only relationships whose endpoints survived.
        unique_entities = self._deduplicate_entities(all_entities)
        valid_relationships = self._validate_relationships(all_relationships, unique_entities)
        return {
            "entities": unique_entities,
            "relationships": valid_relationships,
            "errors": errors if errors else None,
        }

    @staticmethod
    def _normalize_name(name: str) -> str:
        """Canonical key for entity-name comparison: lowercased and stripped.

        Shared by dedup and relationship validation so the two stages agree
        on what counts as "the same entity".
        """
        return name.lower().strip()

    def _deduplicate_entities(self, entities: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Remove duplicate entities, keeping the highest-importance record per name."""
        best_by_name: Dict[str, Dict[str, Any]] = {}
        for entity in entities:
            name = self._normalize_name(entity.get("name", ""))
            if not name:
                continue  # unnamed entities are useless downstream
            current = best_by_name.get(name)
            # Fix: prefer the most important duplicate, not the first-seen one.
            if current is None or entity.get("importance", 0) > current.get("importance", 0):
                best_by_name[name] = entity
        # Most important first, truncated to the configured cap.
        unique_entities = sorted(
            best_by_name.values(), key=lambda x: x.get("importance", 0), reverse=True
        )
        return unique_entities[:self.config.MAX_ENTITIES]

    def _validate_relationships(self, relationships: List[Dict[str, Any]], entities: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Keep only relationships whose source and target both name known entities."""
        # Fix: normalize exactly like _deduplicate_entities (lower AND strip),
        # so whitespace variants still match after dedup.
        entity_names = {self._normalize_name(entity.get("name", "")) for entity in entities}
        valid_relationships = [
            rel for rel in relationships
            if self._normalize_name(rel.get("source", "")) in entity_names
            and self._normalize_name(rel.get("target", "")) in entity_names
        ]
        return valid_relationships[:self.config.MAX_RELATIONSHIPS]