#!/usr/bin/env python3
# This script builds attribution graphs for the OLMo-2 7B model on German-language prompts.
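#
# Usage (example; the script filename is illustrative, the flags come from the argparse setup in main()):
#   python attribution_graphs_de.py --prompt-index 0 [--force-retrain-clt]
# --prompt-index selects one of the entries in ANALYSIS_PROMPTS defined below.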
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List, Tuple, Optional, Any
import json
import logging
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForCausalLM
from collections import defaultdict
import networkx as nx
from dataclasses import dataclass
from tqdm import tqdm
import pickle
import requests
import time
import random
import os
import argparse
# --- Add this block to fix the import path ---
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parent.parent))
# ---------------------------------------------
from utils import init_qwen_api, set_seed
# --- Constants ---
# Configuration for the attribution graph generation pipeline.
RESULTS_DIR = "circuit_analysis/results"
CLT_SAVE_PATH = "circuit_analysis/models/clt_model_de.pth"
# Configure logging.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Set the device for training.
if torch.backends.mps.is_available():
DEVICE = torch.device("mps")
logger.info("Using MPS (Metal Performance Shaders) for GPU acceleration")
elif torch.cuda.is_available():
DEVICE = torch.device("cuda")
logger.info("Using CUDA for GPU acceleration")
else:
DEVICE = torch.device("cpu")
logger.info("Using CPU")
@dataclass
class AttributionGraphConfig:
# Configuration for building the attribution graph.
model_path: str = "./models/OLMo-2-1124-7B"
max_seq_length: int = 512
n_features_per_layer: int = 512
sparsity_lambda: float = 0.01
reconstruction_loss_weight: float = 1.0
batch_size: int = 8
learning_rate: float = 1e-4
training_steps: int = 1000
device: str = str(DEVICE)
pruning_threshold: float = 0.8 # For graph pruning
intervention_strength: float = 5.0 # For perturbation experiments
qwen_api_config: Optional[Dict[str, str]] = None
class JumpReLU(nn.Module):
# The JumpReLU activation function.
def __init__(self, threshold: float = 0.0):
super().__init__()
self.threshold = threshold
def forward(self, x):
return F.relu(x - self.threshold)
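# Note: this is a shifted-ReLU simplification. A strict JumpReLU would pass the input
# through unshifted above the threshold and zero it below, e.g. x * (x > threshold).float().
# With threshold=0.0 (the default used below) both variants reduce to a plain ReLU.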
class CrossLayerTranscoder(nn.Module):
# The Cross-Layer Transcoder (CLT) model.
def __init__(self, model_config: Dict, clt_config: AttributionGraphConfig):
super().__init__()
self.config = clt_config
self.model_config = model_config
self.n_layers = model_config['num_hidden_layers']
self.hidden_size = model_config['hidden_size']
self.n_features = clt_config.n_features_per_layer
# Encoder weights for each layer.
self.encoders = nn.ModuleList([
nn.Linear(self.hidden_size, self.n_features, bias=False)
for _ in range(self.n_layers)
])
# Decoder weights for cross-layer connections.
self.decoders = nn.ModuleDict()
for source_layer in range(self.n_layers):
for target_layer in range(source_layer, self.n_layers):
key = f"{source_layer}_to_{target_layer}"
self.decoders[key] = nn.Linear(self.n_features, self.hidden_size, bias=False)
# The activation function.
self.activation = JumpReLU(threshold=0.0)
# Initialize the weights.
self._init_weights()
def _init_weights(self):
# Initializes the weights with small random values.
for module in self.modules():
if isinstance(module, nn.Linear):
nn.init.normal_(module.weight, mean=0.0, std=0.01)
def encode(self, layer_idx: int, residual_activations: torch.Tensor) -> torch.Tensor:
# Encodes residual stream activations to feature activations.
return self.activation(self.encoders[layer_idx](residual_activations))
def decode(self, source_layer: int, target_layer: int, feature_activations: torch.Tensor) -> torch.Tensor:
# Decodes feature activations to the MLP output space.
key = f"{source_layer}_to_{target_layer}"
return self.decoders[key](feature_activations)
def forward(self, residual_activations: List[torch.Tensor]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
# The forward pass of the CLT.
feature_activations = []
reconstructed_mlp_outputs = []
# Encode features for each layer.
for layer_idx, residual in enumerate(residual_activations):
features = self.encode(layer_idx, residual)
feature_activations.append(features)
# Reconstruct MLP outputs with cross-layer connections.
for target_layer in range(self.n_layers):
reconstruction = torch.zeros_like(residual_activations[target_layer])
# Sum contributions from all previous layers.
for source_layer in range(target_layer + 1):
decoded = self.decode(source_layer, target_layer, feature_activations[source_layer])
reconstruction += decoded
reconstructed_mlp_outputs.append(reconstruction)
return feature_activations, reconstructed_mlp_outputs
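# Shape sketch for CrossLayerTranscoder.forward (matching its usage further down):
#   input:  list of n_layers tensors, each [batch, seq_len, hidden_size]
#           (the per-layer hidden states of the base model, i.e. outputs.hidden_states[1:])
#   output: feature_activations       - list of [batch, seq_len, n_features]
#           reconstructed_mlp_outputs - list of [batch, seq_len, hidden_size]
# e.g.  features, recon = clt(outputs.hidden_states[1:])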
class FeatureVisualizer:
# A class to visualize and interpret individual features.
def __init__(self, tokenizer):
self.tokenizer = tokenizer
self.feature_interpretations = {}
def visualize_feature(self, feature_idx: int, layer_idx: int,
activations: torch.Tensor, input_tokens: List[str],
top_k: int = 10) -> Dict:
# Creates a visualization for a single feature.
feature_acts = activations[:, feature_idx].detach().cpu().numpy()
# Find the top activating positions.
top_positions = np.argsort(feature_acts)[-top_k:][::-1]
visualization = {
'feature_idx': feature_idx,
'layer_idx': layer_idx,
'max_activation': float(feature_acts.max()),
'mean_activation': float(feature_acts.mean()),
'sparsity': float((feature_acts > 0.1).mean()),  # fraction of positions with activation > 0.1 (an activation-density measure)
'top_activations': []
}
for pos in top_positions:
if pos < len(input_tokens):
visualization['top_activations'].append({
'token': input_tokens[pos],
'position': int(pos),
'activation': float(feature_acts[pos])
})
return visualization
def interpret_feature(self, feature_idx: int, layer_idx: int,
visualization_data: Dict,
qwen_api_config: Optional[Dict[str, str]] = None) -> str:
# Interprets a feature based on its top activating tokens.
top_tokens = [item['token'] for item in visualization_data['top_activations']]
# Use the Qwen API if it is configured.
if qwen_api_config and qwen_api_config.get('api_key'):
feature_name = f"L{layer_idx}_F{feature_idx}"
interpretation = get_feature_interpretation_with_qwen(
qwen_api_config, top_tokens, feature_name, layer_idx
)
else:
# Use a simple heuristic as a fallback.
if len(set(top_tokens)) == 1:
interpretation = f"Spezifischer Token: '{top_tokens[0]}'"
elif all(token.isalpha() for token in top_tokens):
interpretation = "Wort/alphabetische Tokens"
elif all(token.isdigit() for token in top_tokens):
interpretation = "Numerische Tokens"
elif all(token in '.,!?;:' for token in top_tokens):
interpretation = "Interpunktion"
else:
interpretation = "Gemischte/polysemische Merkmale"
self.feature_interpretations[f"L{layer_idx}_F{feature_idx}"] = interpretation
return interpretation
class AttributionGraph:
# A class to construct and analyze attribution graphs.
def __init__(self, clt: CrossLayerTranscoder, tokenizer):
self.clt = clt
self.tokenizer = tokenizer
self.graph = nx.DiGraph()
self.node_types = {} # Track node types (feature, embedding, error, output)
self.edge_weights = {}
def compute_virtual_weights(self, source_layer: int, target_layer: int,
source_feature: int, target_feature: int) -> float:
# Computes the virtual weight between two features.
if target_layer <= source_layer:
return 0.0
# Get the encoder and decoder weights.
encoder_weight = self.clt.encoders[target_layer].weight[target_feature]
total_weight = 0.0
for intermediate_layer in range(source_layer, target_layer):
decoder_key = f"{source_layer}_to_{intermediate_layer}"
if decoder_key in self.clt.decoders:
decoder_weight = self.clt.decoders[decoder_key].weight[:, source_feature]
# The virtual weight is the inner product.
virtual_weight = torch.dot(decoder_weight, encoder_weight).item()
total_weight += virtual_weight
return total_weight
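# Reading of the quantity computed above: for an upstream feature f_s in layer s and a
# downstream feature f_t in layer t > s, the returned value is
#   sum_{k=s}^{t-1}  < W_dec^{s->k}[:, f_s] , W_enc^{t}[f_t, :] >
# i.e. the inner product of the upstream decoder directions with the downstream encoder
# direction. This is a simplified stand-in for the "virtual weight" used in attribution-graph
# analyses; it ignores attention paths and layer normalisation.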
def construct_graph(self, input_tokens: List[str],
feature_activations: List[torch.Tensor],
target_token_idx: int = -1) -> nx.DiGraph:
# Constructs the attribution graph for a prompt.
self.graph.clear()
self.node_types.clear()
self.edge_weights.clear()
seq_len = len(input_tokens)
n_layers = len(feature_activations)
# Add embedding nodes for the input tokens.
for i, token in enumerate(input_tokens):
node_id = f"emb_{i}_{token}"
self.graph.add_node(node_id)
self.node_types[node_id] = "embedding"
# Add nodes for the features.
active_features = {}
max_features_per_layer = 20
for layer_idx, features in enumerate(feature_activations):
# features shape: [batch_size, seq_len, n_features]
batch_size, seq_len_layer, n_features = features.shape
# Get the top activating features for this layer.
layer_activations = features[0].mean(dim=0)
top_features = torch.topk(layer_activations,
k=min(max_features_per_layer, n_features)).indices
for token_pos in range(min(seq_len, seq_len_layer)):
for feat_idx in top_features:
activation = features[0, token_pos, feat_idx.item()].item()
if activation > 0.05:
node_id = f"feat_L{layer_idx}_T{token_pos}_F{feat_idx.item()}"
self.graph.add_node(node_id)
self.node_types[node_id] = "feature"
active_features[node_id] = {
'layer': layer_idx,
'token_pos': token_pos,
'feature_idx': feat_idx.item(),
'activation': activation
}
# Add an output node for the target token.
output_node = f"output_{target_token_idx}"
self.graph.add_node(output_node)
self.node_types[output_node] = "output"
# Add edges based on virtual weights and activations.
feature_nodes = [node for node, type_ in self.node_types.items() if type_ == "feature"]
print(f" Building attribution graph: {len(feature_nodes)} feature nodes, {len(self.graph.nodes())} total nodes")
# Limit the number of edges to compute.
max_edges_per_node = 5
for i, source_node in enumerate(feature_nodes):
if i % 50 == 0:
print(f" Processing node {i+1}/{len(feature_nodes)}")
edges_added = 0
source_info = active_features[source_node]
source_activation = source_info['activation']
# Add edges to other features.
for target_node in feature_nodes:
if source_node == target_node or edges_added >= max_edges_per_node:
continue
target_info = active_features[target_node]
# Only add edges that go forward in the network.
if (target_info['layer'] > source_info['layer'] or
(target_info['layer'] == source_info['layer'] and
target_info['token_pos'] > source_info['token_pos'])):
virtual_weight = self.compute_virtual_weights(
source_info['layer'], target_info['layer'],
source_info['feature_idx'], target_info['feature_idx']
)
if abs(virtual_weight) > 0.05:
edge_weight = source_activation * virtual_weight
self.graph.add_edge(source_node, target_node, weight=edge_weight)
self.edge_weights[(source_node, target_node)] = edge_weight
edges_added += 1
# Add edges to the output node.
if source_info['layer'] >= n_layers - 2:
output_weight = source_activation * 0.1
self.graph.add_edge(source_node, output_node, weight=output_weight)
self.edge_weights[(source_node, output_node)] = output_weight
# Add edges from embeddings to early features.
for emb_node in [node for node, type_ in self.node_types.items() if type_ == "embedding"]:
token_idx = int(emb_node.split('_')[1])
for feat_node in feature_nodes:
feat_info = active_features[feat_node]
if feat_info['layer'] == 0 and feat_info['token_pos'] == token_idx:
# Direct connection from an embedding to a first-layer feature.
weight = feat_info['activation'] * 0.5
self.graph.add_edge(emb_node, feat_node, weight=weight)
self.edge_weights[(emb_node, feat_node)] = weight
return self.graph
def prune_graph(self, threshold: float = 0.8) -> nx.DiGraph:
# Prunes the graph to keep only the most important nodes.
# Calculate node importance based on edge weights.
node_importance = defaultdict(float)
for (source, target), weight in self.edge_weights.items():
node_importance[source] += abs(weight)
node_importance[target] += abs(weight)
# Keep the top nodes by importance.
sorted_nodes = sorted(node_importance.items(), key=lambda x: x[1], reverse=True)
n_keep = int(len(sorted_nodes) * threshold)  # threshold is interpreted as the fraction of nodes to keep
important_nodes = set([node for node, _ in sorted_nodes[:n_keep]])
# Always keep the output and embedding nodes.
for node, type_ in self.node_types.items():
if type_ in ["output", "embedding"]:
important_nodes.add(node)
# Create the pruned graph.
pruned_graph = self.graph.subgraph(important_nodes).copy()
return pruned_graph
def visualize_graph(self, graph: nx.DiGraph = None, save_path: str = None):
# Visualizes the attribution graph.
if graph is None:
graph = self.graph
plt.figure(figsize=(12, 8))
# Create a layout for the graph.
pos = nx.spring_layout(graph, k=1, iterations=50)
# Color the nodes by type.
node_colors = []
for node in graph.nodes():
node_type = self.node_types.get(node, "unknown")
if node_type == "embedding":
node_colors.append('lightblue')
elif node_type == "feature":
node_colors.append('lightgreen')
elif node_type == "output":
node_colors.append('orange')
else:
node_colors.append('gray')
# Draw the nodes.
nx.draw_networkx_nodes(graph, pos, node_color=node_colors,
node_size=300, alpha=0.8)
# Draw the edges with thickness based on weight.
edges = graph.edges()
edge_weights = [abs(self.edge_weights.get((u, v), 0.1)) for u, v in edges]
max_weight = max(edge_weights) if edge_weights else 1
edge_widths = [w / max_weight * 3 for w in edge_weights]
nx.draw_networkx_edges(graph, pos, width=edge_widths, alpha=0.6,
edge_color='gray', arrows=True)
# Draw the labels.
nx.draw_networkx_labels(graph, pos, font_size=8)
plt.title("Attribution Graph (German)")
plt.axis('off')
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
plt.show()
class PerturbationExperiments:
# Conducts perturbation experiments to validate hypotheses.
def __init__(self, model, clt: CrossLayerTranscoder, tokenizer):
self.model = model
self.clt = clt
self.tokenizer = tokenizer
def feature_ablation_experiment(self, input_text: str,
target_layer: int, target_feature: int,
intervention_strength: float = 5.0) -> Dict:
# Ablates a feature and measures the effect on the model's output.
try:
# Clear the MPS cache to prevent memory issues.
if torch.backends.mps.is_available():
torch.mps.empty_cache()
# Tokenize the input.
inputs = self.tokenizer(input_text, return_tensors="pt", padding=True,
truncation=True, max_length=512)
# Move inputs to the correct device.
device = next(self.model.parameters()).device
inputs = {k: v.to(device) for k, v in inputs.items()}
# Get the baseline predictions.
with torch.no_grad():
baseline_outputs = self.model(**inputs)
baseline_logits = baseline_outputs.logits[0, -1, :]
baseline_probs = F.softmax(baseline_logits, dim=-1)
baseline_top_tokens = torch.topk(baseline_probs, k=5)
# TODO: Implement the actual feature intervention.
# Simulate the effect of the intervention with placeholder values (see the TODO above).
intervention_effect = {
'baseline_top_tokens': [
(self.tokenizer.decode([idx]), prob.item())
for idx, prob in zip(baseline_top_tokens.indices, baseline_top_tokens.values)
],
'intervention_layer': target_layer,
'intervention_feature': target_feature,
'intervention_strength': intervention_strength,
'effect_magnitude': 0.1,
'probability_change': 0.05
}
return intervention_effect
except Exception as e:
# Handle MPS memory issues.
print(f" Warning: Perturbation experiment failed due to device issue: {e}")
return {
'baseline_top_tokens': [],
'intervention_layer': target_layer,
'intervention_feature': target_feature,
'intervention_strength': intervention_strength,
'effect_magnitude': 0.0,
'probability_change': 0.0,
'error': str(e)
}
class AttributionGraphsPipeline:
# The main pipeline for the attribution graph analysis.
def __init__(self, config: AttributionGraphConfig):
self.config = config
self.device = torch.device(config.device)
# Load the model and tokenizer.
logger.info(f"Loading OLMo2 7B model from {config.model_path}")
self.tokenizer = AutoTokenizer.from_pretrained(config.model_path)
# Configure model loading based on the device.
if "mps" in config.device:
# MPS supports float16 but not device_map.
self.model = AutoModelForCausalLM.from_pretrained(
config.model_path,
torch_dtype=torch.float16,
device_map=None
).to(self.device)
elif "cuda" in config.device:
self.model = AutoModelForCausalLM.from_pretrained(
config.model_path,
torch_dtype=torch.float16,
device_map="auto"
)
else:
# CPU
self.model = AutoModelForCausalLM.from_pretrained(
config.model_path,
torch_dtype=torch.float32,
device_map=None
).to(self.device)
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
# Initialize the CLT.
model_config = self.model.config.to_dict()
self.clt = CrossLayerTranscoder(model_config, config).to(self.device)
# Initialize the other components.
self.feature_visualizer = FeatureVisualizer(self.tokenizer)
self.attribution_graph = AttributionGraph(self.clt, self.tokenizer)
self.perturbation_experiments = PerturbationExperiments(self.model, self.clt, self.tokenizer)
logger.info("Attribution Graphs Pipeline initialized successfully")
def train_clt(self, training_texts: List[str]) -> Dict:
# Trains the Cross-Layer Transcoder.
logger.info("Starting CLT training...")
optimizer = torch.optim.Adam(self.clt.parameters(), lr=self.config.learning_rate)
training_stats = {
'reconstruction_losses': [],
'sparsity_losses': [],
'total_losses': []
}
for step in tqdm(range(self.config.training_steps), desc="Training CLT"):
# Sample a batch of texts.
batch_texts = np.random.choice(training_texts, size=self.config.batch_size)
total_loss = 0.0
total_recon_loss = 0.0
total_sparsity_loss = 0.0
for text in batch_texts:
# Tokenize the text.
inputs = self.tokenizer(text, return_tensors="pt", max_length=self.config.max_seq_length,
truncation=True, padding=True).to(self.device)
# Get the model activations.
with torch.no_grad():
outputs = self.model(**inputs, output_hidden_states=True)
hidden_states = outputs.hidden_states[1:]
# Forward pass through the CLT.
feature_activations, reconstructed_outputs = self.clt(hidden_states)
# Compute the reconstruction loss.
recon_loss = 0.0
for i, (target, pred) in enumerate(zip(hidden_states, reconstructed_outputs)):
recon_loss += F.mse_loss(pred, target)
# Compute the sparsity loss.
sparsity_loss = 0.0
for features in feature_activations:
sparsity_loss += torch.mean(torch.tanh(self.config.sparsity_lambda * features))
# Total loss.
loss = (self.config.reconstruction_loss_weight * recon_loss +
self.config.sparsity_lambda * sparsity_loss)
total_loss += loss
total_recon_loss += recon_loss
total_sparsity_loss += sparsity_loss
# Average the losses.
total_loss /= self.config.batch_size
total_recon_loss /= self.config.batch_size
total_sparsity_loss /= self.config.batch_size
# Backward pass.
optimizer.zero_grad()
total_loss.backward()
optimizer.step()
# Log the progress.
training_stats['total_losses'].append(total_loss.item())
training_stats['reconstruction_losses'].append(total_recon_loss.item())
training_stats['sparsity_losses'].append(total_sparsity_loss.item())
if step % 100 == 0:
logger.info(f"Step {step}: Total Loss = {total_loss.item():.4f}, "
f"Recon Loss = {total_recon_loss.item():.4f}, "
f"Sparsity Loss = {total_sparsity_loss.item():.4f}")
logger.info("CLT training completed")
return training_stats
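# Loss sketch (as implemented in train_clt above, per batch element):
#   L = reconstruction_loss_weight * sum_l MSE(reconstruction_l, hidden_state_l)
#     + sparsity_lambda * sum_l mean(tanh(sparsity_lambda * features_l))
# Note that sparsity_lambda appears both inside the tanh and as the penalty weight,
# a simplification of the tanh sparsity penalty used for cross-layer transcoders.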
def analyze_prompt(self, prompt: str, target_token_idx: int = -1) -> Dict:
# Performs a complete analysis for a single prompt.
logger.info(f"Analyzing prompt: '{prompt[:50]}...'")
# Tokenize the prompt.
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
input_tokens = self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
# Get the model activations.
with torch.no_grad():
outputs = self.model(**inputs, output_hidden_states=True)
hidden_states = outputs.hidden_states[1:]
# Forward pass through the CLT.
feature_activations, reconstructed_outputs = self.clt(hidden_states)
logger.info(" > Starting feature visualization and interpretation...")
feature_visualizations = {}
for layer_idx, features in enumerate(feature_activations):
logger.info(f" - Processing Layer {layer_idx}...")
layer_viz = {}
# Analyze the top features for this layer.
# features shape: [batch_size, seq_len, n_features]
feature_importance = torch.mean(features, dim=(0, 1))
top_features = torch.topk(feature_importance, k=min(5, feature_importance.size(0))).indices
for feat_idx in top_features:
viz = self.feature_visualizer.visualize_feature(
feat_idx.item(), layer_idx, features[0], input_tokens
)
interpretation = self.feature_visualizer.interpret_feature(
feat_idx.item(), layer_idx, viz, self.config.qwen_api_config
)
viz['interpretation'] = interpretation
layer_viz[f"feature_{feat_idx.item()}"] = viz
feature_visualizations[f"layer_{layer_idx}"] = layer_viz
# Construct the attribution graph.
graph = self.attribution_graph.construct_graph(
input_tokens, feature_activations, target_token_idx
)
# Prune the graph.
pruned_graph = self.attribution_graph.prune_graph(self.config.pruning_threshold)
# Analyze the most important paths.
important_paths = []
if len(pruned_graph.nodes()) > 0:
# Find paths from embeddings to the output.
embedding_nodes = [node for node, type_ in self.attribution_graph.node_types.items()
if type_ == "embedding" and node in pruned_graph]
output_nodes = [node for node, type_ in self.attribution_graph.node_types.items()
if type_ == "output" and node in pruned_graph]
for emb_node in embedding_nodes[:3]:
for out_node in output_nodes:
try:
paths = list(nx.all_simple_paths(pruned_graph, emb_node, out_node, cutoff=5))
for path in paths[:2]:
path_weight = 1.0
for i in range(len(path) - 1):
edge_weight = self.attribution_graph.edge_weights.get(
(path[i], path[i+1]), 0.0
)
path_weight *= abs(edge_weight)
important_paths.append({
'path': path,
'weight': path_weight,
'description': self._describe_path(path)
})
except (nx.NetworkXNoPath, nx.NodeNotFound):  # all_simple_paths raises NodeNotFound for missing endpoints
continue
# Sort paths by importance.
important_paths.sort(key=lambda x: x['weight'], reverse=True)
results = {
'prompt': prompt,
'input_tokens': input_tokens,
'feature_visualizations': feature_visualizations,
'full_graph_stats': {
'n_nodes': len(graph.nodes()),
'n_edges': len(graph.edges()),
'node_types': dict(self.attribution_graph.node_types)
},
'pruned_graph_stats': {
'n_nodes': len(pruned_graph.nodes()),
'n_edges': len(pruned_graph.edges())
},
'important_paths': important_paths[:5],
'graph': pruned_graph
}
return results
def _describe_path(self, path: List[str]) -> str:
# Generates a human-readable description of a path.
descriptions = []
for node in path:
if self.attribution_graph.node_types[node] == "embedding":
token = node.split('_')[2]
descriptions.append(f"Token '{token}'")
elif self.attribution_graph.node_types[node] == "feature":
parts = node.split('_')
layer = parts[1][1:]
feature = parts[3][1:]
# Try to get the interpretation.
key = f"L{layer}_F{feature}"
interpretation = self.feature_visualizer.feature_interpretations.get(key, "unknown")
descriptions.append(f"Feature L{layer}F{feature} ({interpretation})")
elif self.attribution_graph.node_types[node] == "output":
descriptions.append("Output")
return " → ".join(descriptions)
def save_results(self, results: Dict, save_path: str):
# Saves the analysis results to a file.
# Convert the graph to a serializable format.
serializable_results = results.copy()
if 'graph' in serializable_results:
graph_data = nx.node_link_data(serializable_results['graph'])
serializable_results['graph'] = graph_data
with open(save_path, 'w', encoding='utf-8') as f:
json.dump(serializable_results, f, indent=2, default=str, ensure_ascii=False)
logger.info(f"Results saved to {save_path}")
def save_clt(self, path: str):
# Saves the trained CLT model.
torch.save(self.clt.state_dict(), path)
logger.info(f"CLT model saved to {path}")
def load_clt(self, path: str):
# Loads a trained CLT model.
self.clt.load_state_dict(torch.load(path, map_location=self.device))
self.clt.to(self.device)
self.clt.eval()
logger.info(f"Loaded CLT model from {path}")
# --- Configuration ---
MAX_SEQ_LEN = 256
N_FEATURES_PER_LAYER = 512
TRAINING_STEPS = 2500
BATCH_SIZE = 64
LEARNING_RATE = 1e-3
# German prompts for the final analysis.
ANALYSIS_PROMPTS = [
"Die Hauptstadt von Frankreich ist",
"def fakultaet(n):",
"Das literarische Stilmittel im Satz 'Der Wind flüsterte durch die Bäume' ist"
]
# A larger set of German prompts for training.
TRAINING_PROMPTS = [
"Die Hauptstadt von Frankreich ist", "Sein oder Nichtsein, das ist hier die Frage", "Was du heute kannst besorgen, das verschiebe nicht auf morgen",
"Der erste Mensch auf dem Mond war", "Die chemische Formel für Wasser ist H2O.",
"Übersetze ins Englische: 'Die Katze sitzt auf der Matte.'", "def fakultaet(n):", "import numpy as np",
"Die Hauptzutaten einer Pizza sind", "Was ist das Kraftwerk der Zelle?",
"Die Gleichung E=mc^2 beschreibt die Beziehung zwischen Energie und", "Setze die Geschichte fort: Es war einmal, da war ein",
"Klassifiziere das Sentiment: 'Ich bin überglücklich!'", "Extrahiere die Entitäten: 'Apple Inc. ist in Cupertino.'",
"Was ist die nächste Zahl: 2, 4, 8, 16, __?", "Ein rollender Stein setzt kein Moos an",
"Das Gegenteil von heiß ist", "import torch", "import pandas as pd", "class MeineKlasse:",
"def __init__(self):", "Die Primärfarben sind", "Was ist die Hauptstadt von Japan?",
"Wer hat 'Hamlet' geschrieben?", "Die Quadratwurzel von 64 ist", "Die Sonne geht im Osten auf",
"Der Pazifische Ozean ist der größte Ozean der Erde.", "Die Mitochondrien sind das Kraftwerk der Zelle.",
"Was ist die Hauptstadt der Mongolei?", "Der Film 'Matrix' kann folgendem Genre zugeordnet werden:",
"Die englische Übersetzung von 'Ich möchte bitte einen Kaffee bestellen.' lautet:",
"Das literarische Stilmittel im Satz 'Der Wind flüsterte durch die Bäume' ist",
"Eine Python-Funktion, die die Fakultät einer Zahl berechnet, lautet:",
"Die Hauptzutat eines Negroni-Cocktails ist",
"Fasse die Handlung von 'Hamlet' in einem Satz zusammen:",
"Der Satz 'Der Kuchen wurde vom Hund gefressen' steht in folgender Form:",
"Eine gute Überschrift für einen Artikel über einen neuen Durchbruch in der Batterietechnologie wäre:"
]
# --- Qwen API for Feature Interpretation ---
@torch.no_grad()
def get_feature_interpretation_with_qwen(
api_config: dict,
top_tokens: list[str],
feature_name: str,
layer_index: int,
max_retries: int = 3,
initial_backoff: float = 2.0
) -> str:
# Generates a high-quality interpretation for a feature using the Qwen API.
if not api_config or not api_config.get('api_key'):
logger.warning("Qwen API not configured. Skipping interpretation.")
return "API not configured"
headers = {
"Authorization": f"Bearer {api_config['api_key']}",
"Content-Type": "application/json"
}
# Create a specialized German prompt.
prompt_text = f"""
Sie sind ein Experte für die Interpretierbarkeit von Transformern. Ein Merkmal in einem Sprachmodell (Merkmal '{feature_name}' auf Schicht {layer_index}) wird am stärksten durch die folgenden Token aktiviert:
{', '.join(f"'{token}'" for token in top_tokens)}
Was ist, basierend *nur* auf diesen Token, die wahrscheinlichste Funktion oder Rolle dieses Merkmals?
Ihre Antwort muss ein kurzer, prägnanter Ausdruck sein (z.B. "Erkennen von Eigennamen", "Identifizieren von JSON-Syntax", "Vervollständigen von Listen", "Erkennen negativer Stimmung"). Schreiben Sie keinen ganzen Satz.
"""
data = {
"model": api_config["model"],
"messages": [
{
"role": "user",
"content": [{"type": "text", "text": prompt_text}]
}
],
"max_tokens": 50,
"temperature": 0.1,
"top_p": 0.9,
"seed": 42
}
logger.info(f" > Interpreting {feature_name} (Layer {layer_index})...")
for attempt in range(max_retries):
try:
logger.info(f" - Attempt {attempt + 1}/{max_retries}: Sending request to Qwen API...")
response = requests.post(
f"{api_config['api_endpoint']}/chat/completions",
headers=headers,
json=data,
timeout=60
)
response.raise_for_status()
result = response.json()
interpretation = result["choices"][0]["message"]["content"].strip()
# Remove quotes from the output.
if interpretation.startswith('"') and interpretation.endswith('"'):
interpretation = interpretation[1:-1]
logger.info(f" - Success! Interpretation: '{interpretation}'")
return interpretation
except requests.exceptions.RequestException as e:
logger.warning(f" - Qwen API request failed (Attempt {attempt + 1}/{max_retries}): {e}")
if attempt < max_retries - 1:
backoff_time = initial_backoff * (2 ** attempt)
logger.info(f" - Retrying in {backoff_time:.1f} seconds...")
time.sleep(backoff_time)
else:
logger.error(" - Max retries reached. Failing.")
return f"API Error: {e}"
except (KeyError, IndexError) as e:
logger.error(f" - Failed to parse Qwen API response: {e}")
return "API Error: Invalid response format"
finally:
# Add a delay to respect API rate limits.
time.sleep(2.1)
return "API Error: Max retries exceeded"
def train_transcoder(transcoder, model, tokenizer, training_prompts, device, steps=1000, batch_size=16, optimizer=None):
# Trains the Cross-Layer Transcoder.
transcoder.train()
# Use a progress bar for visual feedback.
progress_bar = tqdm(range(steps), desc="Training CLT")
for step in progress_bar:
# Get a random batch of prompts.
batch_prompts = random.choices(training_prompts, k=batch_size)
# Tokenize the batch.
inputs = tokenizer(
batch_prompts,
return_tensors="pt",
padding=True,
truncation=True,
max_length=MAX_SEQ_LEN
)
inputs = {k: v.to(device) for k, v in inputs.items()}
# Get the model activations.
with torch.no_grad():
outputs = model(**inputs, output_hidden_states=True)
hidden_states = outputs.hidden_states[1:]
# Forward pass through the CLT.
feature_activations, reconstructed_outputs = transcoder(hidden_states)
# Compute the reconstruction loss.
recon_loss = 0.0
for i, (target, pred) in enumerate(zip(hidden_states, reconstructed_outputs)):
recon_loss += F.mse_loss(pred, target)
# Compute the sparsity loss.
sparsity_loss = 0.0
for features in feature_activations:
sparsity_loss += torch.mean(torch.tanh(0.01 * features))
# Total loss.
loss = (0.8 * recon_loss + 0.2 * sparsity_loss)
if optimizer:
optimizer.zero_grad()
loss.backward()
optimizer.step()
progress_bar.set_postfix({
"Recon Loss": f"{recon_loss.item():.4f}",
"Sparsity Loss": f"{sparsity_loss.item():.4f}",
"Total Loss": f"{loss.item():.4f}"
})
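# Usage sketch for the standalone trainer above (the optimizer must be supplied,
# otherwise no gradient step is taken):
#   opt = torch.optim.Adam(clt.parameters(), lr=LEARNING_RATE)
#   train_transcoder(clt, model, tokenizer, TRAINING_PROMPTS, DEVICE,
#                    steps=TRAINING_STEPS, batch_size=BATCH_SIZE, optimizer=opt)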
def generate_feature_visualizations(transcoder, model, tokenizer, prompt, device, qwen_api_config=None, graph_config: Optional[AttributionGraphConfig] = None):
# Generates feature visualizations and interpretations for a prompt.
# Tokenize the prompt.
inputs = tokenizer(
prompt,
return_tensors="pt",
padding=True,
truncation=True,
max_length=MAX_SEQ_LEN
)
inputs = {k: v.to(device) for k, v in inputs.items()}
# Get the model activations.
with torch.no_grad():
outputs = model(**inputs, output_hidden_states=True)
hidden_states = outputs.hidden_states[1:]
# Forward pass through the CLT.
feature_activations, reconstructed_outputs = transcoder(hidden_states)
# Visualize the features (reuse a single visualizer and tokenize the prompt once).
feature_visualizer = FeatureVisualizer(tokenizer)
input_tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
feature_visualizations = {}
for layer_idx, features in enumerate(feature_activations):
layer_viz = {}
# Analyze the top features for this layer.
# features shape: [batch_size, seq_len, n_features]
feature_importance = torch.mean(features, dim=(0, 1))
top_features = torch.topk(feature_importance, k=min(5, feature_importance.size(0))).indices
for feat_idx in top_features:
viz = feature_visualizer.visualize_feature(
feat_idx.item(), layer_idx, features[0], input_tokens
)
interpretation = feature_visualizer.interpret_feature(
feat_idx.item(), layer_idx, viz, qwen_api_config
)
viz['interpretation'] = interpretation
layer_viz[f"feature_{feat_idx.item()}"] = viz
feature_visualizations[f"layer_{layer_idx}"] = layer_viz
# Construct the attribution graph.
if graph_config is None:
graph_config = AttributionGraphConfig()
attribution_graph = AttributionGraph(transcoder, tokenizer)  # AttributionGraph takes only the transcoder and tokenizer
graph = attribution_graph.construct_graph(input_tokens, feature_activations, -1)
# Prune the graph.
pruned_graph = attribution_graph.prune_graph(graph_config.pruning_threshold)
# Analyze the most important paths.
important_paths = []
if len(pruned_graph.nodes()) > 0:
# Find paths from embeddings to the output.
embedding_nodes = [node for node, type_ in attribution_graph.node_types.items()
if type_ == "embedding" and node in pruned_graph]
output_nodes = [node for node, type_ in attribution_graph.node_types.items()
if type_ == "output" and node in pruned_graph]
for emb_node in embedding_nodes[:3]:
for out_node in output_nodes:
try:
paths = list(nx.all_simple_paths(pruned_graph, emb_node, out_node, cutoff=5))
for path in paths[:2]:
path_weight = 1.0
for i in range(len(path) - 1):
edge_weight = attribution_graph.edge_weights.get(
(path[i], path[i+1]), 0.0
)
path_weight *= abs(edge_weight)
important_paths.append({
'path': path,
'weight': path_weight,
'description': " → ".join(path)  # AttributionGraph has no _describe_path helper; fall back to the raw node ids
})
except (nx.NetworkXNoPath, nx.NodeNotFound):  # all_simple_paths raises NodeNotFound for missing endpoints
continue
# Sort paths by importance.
important_paths.sort(key=lambda x: x['weight'], reverse=True)
return {
"prompt": prompt,
"full_graph_stats": {
"n_nodes": len(graph.nodes()),
"n_edges": len(graph.edges()),
"node_types": dict(attribution_graph.node_types)
},
"pruned_graph_stats": {
"n_nodes": len(pruned_graph.nodes()),
"n_edges": len(pruned_graph.edges())
},
"feature_visualizations": feature_visualizations,
"important_paths": important_paths[:5]
}
def main():
# Main function to run the analysis for a single prompt.
# Set a seed for reproducibility.
set_seed()
# --- Argument Parser ---
parser = argparse.ArgumentParser(description="Run Attribution Graph analysis for a single prompt.")
parser.add_argument(
'--prompt-index',
type=int,
required=True,
help=f"The 0-based index of the prompt to analyze from the ANALYSIS_PROMPTS list (0 to {len(ANALYSIS_PROMPTS) - 1})."
)
parser.add_argument(
'--force-retrain-clt',
action='store_true',
help="Force re-training of the Cross-Layer Transcoder, even if a saved model exists."
)
args = parser.parse_args()
prompt_idx = args.prompt_index
if not (0 <= prompt_idx < len(ANALYSIS_PROMPTS)):
print(f"❌ Error: --prompt-index must be between 0 and {len(ANALYSIS_PROMPTS) - 1}.")
return
# Get the API config from the utility function.
qwen_api_config = init_qwen_api()
# Configuration
config = AttributionGraphConfig(
model_path="./models/OLMo-2-1124-7B",
n_features_per_layer=512,
training_steps=500,
batch_size=4,
max_seq_length=256,
learning_rate=1e-4,
sparsity_lambda=0.01,
qwen_api_config=qwen_api_config
)
print("Attribution Graphs for OLMo2 7B - Single Prompt Pipeline (German)")
print("=" * 50)
print(f"Model path: {config.model_path}")
print(f"Device: {config.device}")
try:
# Initialize the full pipeline.
print("🚀 Initializing Attribution Graphs Pipeline...")
pipeline = AttributionGraphsPipeline(config)
print("✓ Pipeline initialized successfully")
print()
# Load an existing CLT model or train a new one.
if os.path.exists(CLT_SAVE_PATH) and not args.force_retrain_clt:
print(f"🧠 Loading existing CLT model from {CLT_SAVE_PATH}...")
pipeline.load_clt(CLT_SAVE_PATH)
print("✓ CLT model loaded successfully.")
else:
if args.force_retrain_clt and os.path.exists(CLT_SAVE_PATH):
print("🏃♂️ --force-retrain-clt flag is set. Overwriting existing model.")
# Train a new CLT model.
print("📚 Training a new CLT model...")
print(f" Training on {len(TRAINING_PROMPTS)} example texts...")
training_stats = pipeline.train_clt(TRAINING_PROMPTS)
print("✓ CLT training completed.")
# Save the new model.
pipeline.save_clt(CLT_SAVE_PATH)
print(f" Saved trained model to {CLT_SAVE_PATH} for future use.")
print()
# Analyze the selected prompt.
prompt_to_analyze = ANALYSIS_PROMPTS[prompt_idx]
print(f"🔍 Analyzing prompt {prompt_idx + 1}/{len(ANALYSIS_PROMPTS)}: '{prompt_to_analyze}'")
analysis = pipeline.analyze_prompt(prompt_to_analyze, target_token_idx=-1)
# Display the key results.
print(f" ✓ Tokenized into {len(analysis['input_tokens'])} tokens")
print(f" ✓ Full graph: {analysis['full_graph_stats']['n_nodes']} nodes, {analysis['full_graph_stats']['n_edges']} edges")
print(f" ✓ Pruned graph: {analysis['pruned_graph_stats']['n_nodes']} nodes, {analysis['pruned_graph_stats']['n_edges']} edges")
# Show the top features.
print(" 📊 Top active features:")
for layer_name, layer_features in list(analysis['feature_visualizations'].items())[:3]:
print(f" {layer_name}:")
for feat_name, feat_data in list(layer_features.items())[:2]:
print(f" {feat_name}: {feat_data['interpretation']} (max: {feat_data['max_activation']:.3f})")
print()
# Run a perturbation experiment.
print("🧪 Running perturbation experiment...")
# (No need to pass training_stats to the experiment)
if analysis['feature_visualizations']:
first_layer_key = next(iter(analysis['feature_visualizations']), None)
if first_layer_key:
layer_idx = int(first_layer_key.split('_')[1])
first_feature_key = next(iter(analysis['feature_visualizations'][first_layer_key]), None)
if first_feature_key:
feature_idx = int(first_feature_key.split('_')[1])
ablation_result = pipeline.perturbation_experiments.feature_ablation_experiment(
prompt_to_analyze, layer_idx, feature_idx, intervention_strength=3.0
)
print(f" Ablated L{layer_idx}F{feature_idx}: Δ probability = {ablation_result['probability_change']:.4f}")
print("✓ Perturbation experiment completed")
print()
# Generate a visualization for the prompt.
print("📈 Generating visualization...")
if 'graph' in analysis and analysis['pruned_graph_stats']['n_nodes'] > 0:
viz_path = os.path.join(RESULTS_DIR, f"attribution_graph_prompt_de_{prompt_idx + 1}.png")
pipeline.attribution_graph.visualize_graph(analysis['graph'], save_path=viz_path)
print(f" ✓ Graph visualization saved to {viz_path}")
else:
print(" - Skipping visualization as no graph was generated or it was empty.")
# Save the results in a format for the web app.
save_path = os.path.join(RESULTS_DIR, f"attribution_graphs_results_de_prompt_{prompt_idx + 1}.json")
# Create a JSON file that can be merged with others.
final_results = {
"analyses": {
f"prompt_de_{prompt_idx + 1}": analysis
},
"config": config.__dict__,
"timestamp": str(time.time())
}
# The web page doesn't use the graph object, so remove it.
if 'graph' in final_results['analyses'][f"prompt_de_{prompt_idx + 1}"]:
del final_results['analyses'][f"prompt_de_{prompt_idx + 1}"]['graph']
pipeline.save_results(final_results, save_path)
print(f"💾 Results saved to {save_path}")
print("\n🎉 Analysis for this prompt completed successfully!")
except Exception as e:
print(f"❌ Error during execution: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
main()