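"""Build a FAISS similarity index over a Dolma dataset sample for the Influence Tracer.

Embeds each document's 'text' field with a multilingual sentence-transformer model,
stores the vectors in a FAISS inner-product index, and keeps a JSON mapping from
FAISS ids back to document metadata. Progress is checkpointed so an interrupted
build can be resumed.
"""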
import os
import json
import sys

# Third-party dependencies; exit with an install hint if any are missing.
try:
    import numpy as np
    import faiss
    import torch
    from sentence_transformers import SentenceTransformer
    from tqdm import tqdm as tqdm_iterator
except ImportError:
    print("Missing required libraries. Please run: pip install sentence-transformers faiss-cpu numpy tqdm")
    sys.exit(1)

# Configuration for the script.
DOLMA_DIR = os.path.join("influence_tracer", "dolma_dataset_sample_1.6v")
INDEX_DIR = os.path.join("influence_tracer", "influence_tracer_data")
INDEX_PATH = os.path.join(INDEX_DIR, "dolma_index_multi.faiss")
MAPPING_PATH = os.path.join(INDEX_DIR, "dolma_mapping_multi.json")
STATE_PATH = os.path.join(INDEX_DIR, "index_build_state_multi.json")
MODEL_NAME = 'paraphrase-multilingual-mpnet-base-v2'
# Performance tuning.
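# BATCH_SIZE is the number of documents embedded per model.encode() call;
# SAVE_INTERVAL is the number of files processed between checkpoints.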
BATCH_SIZE = 131072
SAVE_INTERVAL = 10
def build_index():
    """Scans the Dolma dataset, creates vector embeddings, and builds a FAISS index."""
    print("--- Starting Influence Tracer Index Build (Optimized for Speed) ---")

    if not os.path.exists(DOLMA_DIR):
        print(f"Error: Dolma directory not found at '{DOLMA_DIR}'")
        print("Please ensure the dolma_dataset_sample_1.6v directory is in your project root.")
        sys.exit(1)

    os.makedirs(INDEX_DIR, exist_ok=True)
    # Load or initialize the state to allow resuming.
    processed_files = []
    doc_id_counter = 0
    total_docs_processed = 0
    doc_mapping = {}

    if os.path.exists(STATE_PATH):
        print("Found existing state. Attempting to resume...")
        try:
            with open(STATE_PATH, 'r', encoding='utf-8') as f:
                state = json.load(f)
            processed_files = state.get('processed_files', [])
            doc_id_counter = state.get('doc_id_counter', 0)
            total_docs_processed = state.get('total_docs_processed', 0)
            with open(MAPPING_PATH, 'r', encoding='utf-8') as f:
                doc_mapping = json.load(f)
            print(f"Reading existing index from {INDEX_PATH}...")
            index = faiss.read_index(INDEX_PATH)
            print(f"Resumed from state: {len(processed_files)} files processed, {total_docs_processed} documents indexed.")
        except (IOError, json.JSONDecodeError, RuntimeError) as e:
            print(f"Error resuming from state: {e}. Starting fresh.")
            processed_files = []
            doc_id_counter = 0
            total_docs_processed = 0
            doc_mapping = {}
            index = None  # Will be re-initialized
    else:
        print("No existing state found. Starting fresh.")
        index = None
    # Detect the best device to use (MPS, CUDA, or CPU).
    device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device.upper()}")

    # Load the sentence transformer model.
    print(f"Loading sentence transformer model: '{MODEL_NAME}'...")
    try:
        model = SentenceTransformer(MODEL_NAME, device=device)
    except Exception as e:
        print(f"Error loading model: {e}")
        print("Please ensure you have an internet connection and the required libraries are installed.")
        print("Try running: pip install sentence-transformers faiss-cpu numpy tqdm")
        sys.exit(1)
    print("Model loaded successfully.")
    # Initialize the FAISS index if it wasn't loaded.
    if index is None:
        embedding_dim = model.get_sentence_embedding_dimension()
        # Use Inner Product for cosine similarity.
        index = faiss.IndexFlatIP(embedding_dim)
        print(f"FAISS index initialized with dimension {embedding_dim} using Inner Product (IP) for similarity.")
    # Get a list of all files to process.
    print(f"Scanning for documents in '{DOLMA_DIR}'...")
    all_files = sorted([os.path.join(DOLMA_DIR, f) for f in os.listdir(DOLMA_DIR) if f.endswith('.json')])
    files_to_process = [f for f in all_files if os.path.basename(f) not in processed_files]

    if not files_to_process:
        if processed_files:
            print("✅ All files have been processed. Index is up to date.")
            print("--- Index Build Complete ---")
            return
        else:
            print(f"Error: No JSON files found in '{DOLMA_DIR}'.")
            sys.exit(1)

    print(f"Found {len(all_files)} total files, {len(files_to_process)} remaining to process.")
    # Process each file.
    print(f"Processing remaining files with batch size {BATCH_SIZE}...")
    files_processed_since_save = 0
    for file_idx, path in enumerate(tqdm_iterator(files_to_process, desc="Processing files")):
        texts_batch = []
        batch_doc_info = []
        try:
            with open(path, 'r', encoding='utf-8') as f:
                for line in f:
                    try:
                        data = json.loads(line)
                        text = data.get('text', '')
                        if text:
                            texts_batch.append(text)
                            batch_doc_info.append({
                                'id': doc_id_counter,
                                'info': {
                                    'source': data.get('source', 'Unknown'),
                                    'file': os.path.basename(path),
                                    'text_snippet': text[:200] + '...'
                                }
                            })
                            doc_id_counter += 1

                            # Process the batch when it's full.
                            if len(texts_batch) >= BATCH_SIZE:
                                embeddings = model.encode(texts_batch, show_progress_bar=False, convert_to_numpy=True, normalize_embeddings=True)
                                index.add(embeddings.astype('float32'))
                                # Update the document mapping.
                                for doc in batch_doc_info:
                                    doc_mapping[str(doc['id'])] = doc['info']
                                total_docs_processed += len(texts_batch)
                                texts_batch = []
                                batch_doc_info = []
                    except json.JSONDecodeError:
                        continue

            # Process any remaining documents in the last batch.
            if texts_batch:
                embeddings = model.encode(texts_batch, show_progress_bar=False, convert_to_numpy=True, normalize_embeddings=True)
                index.add(embeddings.astype('float32'))
                # Update the mapping for the final batch.
                for doc in batch_doc_info:
                    doc_mapping[str(doc['id'])] = doc['info']
                total_docs_processed += len(texts_batch)

            # Save progress periodically.
            processed_files.append(os.path.basename(path))
            files_processed_since_save += 1
            if files_processed_since_save >= SAVE_INTERVAL or file_idx == len(files_to_process) - 1:
                print(f"\nSaving progress ({total_docs_processed} docs processed)...")
                faiss.write_index(index, INDEX_PATH)
                with open(MAPPING_PATH, 'w', encoding='utf-8') as f:
                    json.dump(doc_mapping, f)
                current_state = {
                    'processed_files': processed_files,
                    'doc_id_counter': doc_id_counter,
                    'total_docs_processed': total_docs_processed
                }
                with open(STATE_PATH, 'w', encoding='utf-8') as f:
                    json.dump(current_state, f)
                files_processed_since_save = 0
                print("Progress saved.")
        except IOError as e:
            print(f"Warning: Could not read or parse {path}. Skipping. Error: {e}")
            continue
    if index.ntotal == 0:
        print("Error: No text could be extracted from the documents. Cannot build index.")
        sys.exit(1)

    print(f"\n🎉 Total documents processed: {total_docs_processed}")
    print("✅ --- Index Build Complete ---")
    print(f"Created index for {index.ntotal} documents.")
if __name__ == "__main__":
    # This allows the script to be run from the command line.
    print("This script will build a searchable index from your Dolma dataset.")
    print("It needs to download a model and process all documents, so it may take some time.")

    # Check for required libraries.
    try:
        import sentence_transformers
        import faiss
        import numpy
        import tqdm
    except ImportError:
        print("\n--- Missing Required Libraries ---")
        print("To run this script, please install the necessary packages by running:")
        print("pip install sentence-transformers faiss-cpu numpy tqdm")
        print("---------------------------------\n")
        sys.exit(1)

    build_index()
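
# Rough sketch of how the resulting artifacts could be queried (illustrative only,
# not executed by this script; the query text and k=5 are placeholder values):
#
#   model = SentenceTransformer(MODEL_NAME)
#   index = faiss.read_index(INDEX_PATH)
#   with open(MAPPING_PATH, 'r', encoding='utf-8') as f:
#       doc_mapping = json.load(f)
#   query_vec = model.encode(["example query"], convert_to_numpy=True, normalize_embeddings=True)
#   scores, ids = index.search(query_vec.astype('float32'), 5)
#   for score, doc_id in zip(scores[0], ids[0]):
#       print(score, doc_mapping[str(doc_id)]['text_snippet'])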