#!/usr/bin/env python3
# This script generates attribution graphs for the OLMo-2 7B model on German prompts.
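# Usage (assuming this file lives under circuit_analysis/ in the repository root):
#   python circuit_analysis/attribution_graphs_olmo_de.py --prompt-index 0
#   python circuit_analysis/attribution_graphs_olmo_de.py --prompt-index 1 --force-retrain-clt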
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List, Tuple, Optional, Any
import json
import logging
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForCausalLM
from collections import defaultdict
import networkx as nx
from dataclasses import dataclass
from tqdm import tqdm
import pickle
import requests
import time
import random
import os
import argparse
# --- Make the repository root importable so the local `utils` module can be found ---
import sys
sys.path.append(str(Path(__file__).resolve().parent.parent))
# --------------------------------------------------------------------------------------
from utils import init_qwen_api, set_seed
# --- Constants ---
# Configuration for the attribution graph generation pipeline.
RESULTS_DIR = "circuit_analysis/results"
CLT_SAVE_PATH = "circuit_analysis/models/clt_model_de.pth"
# Configure logging.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Set the device for training.
if torch.backends.mps.is_available():
DEVICE = torch.device("mps")
logger.info("Using MPS (Metal Performance Shaders) for GPU acceleration")
elif torch.cuda.is_available():
DEVICE = torch.device("cuda")
logger.info("Using CUDA for GPU acceleration")
else:
DEVICE = torch.device("cpu")
logger.info("Using CPU")
@dataclass
class AttributionGraphConfig:
# Configuration for building the attribution graph.
model_path: str = "./models/OLMo-2-1124-7B"
max_seq_length: int = 512
n_features_per_layer: int = 512
sparsity_lambda: float = 0.01
reconstruction_loss_weight: float = 1.0
batch_size: int = 8
learning_rate: float = 1e-4
training_steps: int = 1000
device: str = str(DEVICE)
pruning_threshold: float = 0.8 # For graph pruning
intervention_strength: float = 5.0 # For perturbation experiments
qwen_api_config: Optional[Dict[str, str]] = None
class JumpReLU(nn.Module):
# The JumpReLU activation function.
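    # Implemented here as a shifted ReLU, max(0, x - threshold). A strict JumpReLU as used
    # in the sparse-coding literature gates the raw value (x * 1[x > threshold]) instead;
    # with threshold = 0 both reduce to a standard ReLU.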
def __init__(self, threshold: float = 0.0):
super().__init__()
self.threshold = threshold
def forward(self, x):
return F.relu(x - self.threshold)
class CrossLayerTranscoder(nn.Module):
# The Cross-Layer Transcoder (CLT) model.
def __init__(self, model_config: Dict, clt_config: AttributionGraphConfig):
super().__init__()
self.config = clt_config
self.model_config = model_config
self.n_layers = model_config['num_hidden_layers']
self.hidden_size = model_config['hidden_size']
self.n_features = clt_config.n_features_per_layer
# Encoder weights for each layer.
self.encoders = nn.ModuleList([
nn.Linear(self.hidden_size, self.n_features, bias=False)
for _ in range(self.n_layers)
])
# Decoder weights for cross-layer connections.
self.decoders = nn.ModuleDict()
for source_layer in range(self.n_layers):
for target_layer in range(source_layer, self.n_layers):
key = f"{source_layer}_to_{target_layer}"
self.decoders[key] = nn.Linear(self.n_features, self.hidden_size, bias=False)
# The activation function.
self.activation = JumpReLU(threshold=0.0)
# Initialize the weights.
self._init_weights()
def _init_weights(self):
# Initializes the weights with small random values.
for module in self.modules():
if isinstance(module, nn.Linear):
nn.init.normal_(module.weight, mean=0.0, std=0.01)
def encode(self, layer_idx: int, residual_activations: torch.Tensor) -> torch.Tensor:
# Encodes residual stream activations to feature activations.
return self.activation(self.encoders[layer_idx](residual_activations))
def decode(self, source_layer: int, target_layer: int, feature_activations: torch.Tensor) -> torch.Tensor:
# Decodes feature activations to the MLP output space.
key = f"{source_layer}_to_{target_layer}"
return self.decoders[key](feature_activations)
def forward(self, residual_activations: List[torch.Tensor]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
# The forward pass of the CLT.
feature_activations = []
reconstructed_mlp_outputs = []
# Encode features for each layer.
for layer_idx, residual in enumerate(residual_activations):
features = self.encode(layer_idx, residual)
feature_activations.append(features)
# Reconstruct MLP outputs with cross-layer connections.
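        # The reconstruction of target layer L is the sum of decoded contributions from
        # every layer up to and including L:
        #   recon_L = sum_{s <= L} D_{s->L}(f_s)
        # where f_s are the feature activations of source layer s and D_{s->L} is the
        # corresponding cross-layer decoder.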
for target_layer in range(self.n_layers):
reconstruction = torch.zeros_like(residual_activations[target_layer])
# Sum contributions from all previous layers.
for source_layer in range(target_layer + 1):
decoded = self.decode(source_layer, target_layer, feature_activations[source_layer])
reconstruction += decoded
reconstructed_mlp_outputs.append(reconstruction)
return feature_activations, reconstructed_mlp_outputs
class FeatureVisualizer:
# A class to visualize and interpret individual features.
def __init__(self, tokenizer):
self.tokenizer = tokenizer
self.feature_interpretations = {}
def visualize_feature(self, feature_idx: int, layer_idx: int,
activations: torch.Tensor, input_tokens: List[str],
top_k: int = 10) -> Dict:
# Creates a visualization for a single feature.
feature_acts = activations[:, feature_idx].detach().cpu().numpy()
# Find the top activating positions.
top_positions = np.argsort(feature_acts)[-top_k:][::-1]
visualization = {
'feature_idx': feature_idx,
'layer_idx': layer_idx,
'max_activation': float(feature_acts.max()),
'mean_activation': float(feature_acts.mean()),
'sparsity': float((feature_acts > 0.1).mean()),
'top_activations': []
}
for pos in top_positions:
if pos < len(input_tokens):
visualization['top_activations'].append({
'token': input_tokens[pos],
'position': int(pos),
'activation': float(feature_acts[pos])
})
return visualization
def interpret_feature(self, feature_idx: int, layer_idx: int,
visualization_data: Dict,
qwen_api_config: Optional[Dict[str, str]] = None) -> str:
# Interprets a feature based on its top activating tokens.
top_tokens = [item['token'] for item in visualization_data['top_activations']]
# Use the Qwen API if it is configured.
if qwen_api_config and qwen_api_config.get('api_key'):
feature_name = f"L{layer_idx}_F{feature_idx}"
interpretation = get_feature_interpretation_with_qwen(
qwen_api_config, top_tokens, feature_name, layer_idx
)
else:
# Use a simple heuristic as a fallback.
if len(set(top_tokens)) == 1:
interpretation = f"Spezifischer Token: '{top_tokens[0]}'"
elif all(token.isalpha() for token in top_tokens):
interpretation = "Wort/alphabetische Tokens"
elif all(token.isdigit() for token in top_tokens):
interpretation = "Numerische Tokens"
elif all(token in '.,!?;:' for token in top_tokens):
interpretation = "Interpunktion"
else:
interpretation = "Gemischte/polysemische Merkmale"
self.feature_interpretations[f"L{layer_idx}_F{feature_idx}"] = interpretation
return interpretation
class AttributionGraph:
# A class to construct and analyze attribution graphs.
def __init__(self, clt: CrossLayerTranscoder, tokenizer):
self.clt = clt
self.tokenizer = tokenizer
self.graph = nx.DiGraph()
self.node_types = {} # Track node types (feature, embedding, error, output)
self.edge_weights = {}
def compute_virtual_weights(self, source_layer: int, target_layer: int,
source_feature: int, target_feature: int) -> float:
# Computes the virtual weight between two features.
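        # Here the virtual weight from (source_layer, source_feature) to
        # (target_layer, target_feature) is the sum, over the layers the source feature
        # writes to before the target layer, of the inner product between the source
        # feature's decoder column and the target feature's encoder row:
        #   w_virtual = sum_l <W_dec^{source->l}[:, source_feature], W_enc^{target}[target_feature, :]>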
if target_layer <= source_layer:
return 0.0
# Get the encoder and decoder weights.
encoder_weight = self.clt.encoders[target_layer].weight[target_feature]
total_weight = 0.0
for intermediate_layer in range(source_layer, target_layer):
decoder_key = f"{source_layer}_to_{intermediate_layer}"
if decoder_key in self.clt.decoders:
decoder_weight = self.clt.decoders[decoder_key].weight[:, source_feature]
# The virtual weight is the inner product.
virtual_weight = torch.dot(decoder_weight, encoder_weight).item()
total_weight += virtual_weight
return total_weight
def construct_graph(self, input_tokens: List[str],
feature_activations: List[torch.Tensor],
target_token_idx: int = -1) -> nx.DiGraph:
# Constructs the attribution graph for a prompt.
self.graph.clear()
self.node_types.clear()
self.edge_weights.clear()
seq_len = len(input_tokens)
n_layers = len(feature_activations)
# Add embedding nodes for the input tokens.
for i, token in enumerate(input_tokens):
node_id = f"emb_{i}_{token}"
self.graph.add_node(node_id)
self.node_types[node_id] = "embedding"
# Add nodes for the features.
active_features = {}
max_features_per_layer = 20
for layer_idx, features in enumerate(feature_activations):
# features shape: [batch_size, seq_len, n_features]
batch_size, seq_len_layer, n_features = features.shape
# Get the top activating features for this layer.
layer_activations = features[0].mean(dim=0)
top_features = torch.topk(layer_activations,
k=min(max_features_per_layer, n_features)).indices
for token_pos in range(min(seq_len, seq_len_layer)):
for feat_idx in top_features:
activation = features[0, token_pos, feat_idx.item()].item()
if activation > 0.05:
node_id = f"feat_L{layer_idx}_T{token_pos}_F{feat_idx.item()}"
self.graph.add_node(node_id)
self.node_types[node_id] = "feature"
active_features[node_id] = {
'layer': layer_idx,
'token_pos': token_pos,
'feature_idx': feat_idx.item(),
'activation': activation
}
# Add an output node for the target token.
output_node = f"output_{target_token_idx}"
self.graph.add_node(output_node)
self.node_types[output_node] = "output"
# Add edges based on virtual weights and activations.
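        # Edge weight convention: for feature-to-feature edges, the weight is the source
        # feature's activation times the virtual weight between the two features; edges to
        # the output node and from embeddings use fixed heuristic scalings (0.1 and 0.5).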
feature_nodes = [node for node, type_ in self.node_types.items() if type_ == "feature"]
print(f" Building attribution graph: {len(feature_nodes)} feature nodes, {len(self.graph.nodes())} total nodes")
# Limit the number of edges to compute.
max_edges_per_node = 5
for i, source_node in enumerate(feature_nodes):
if i % 50 == 0:
print(f" Processing node {i+1}/{len(feature_nodes)}")
edges_added = 0
source_info = active_features[source_node]
source_activation = source_info['activation']
# Add edges to other features.
for target_node in feature_nodes:
if source_node == target_node or edges_added >= max_edges_per_node:
continue
target_info = active_features[target_node]
# Only add edges that go forward in the network.
if (target_info['layer'] > source_info['layer'] or
(target_info['layer'] == source_info['layer'] and
target_info['token_pos'] > source_info['token_pos'])):
virtual_weight = self.compute_virtual_weights(
source_info['layer'], target_info['layer'],
source_info['feature_idx'], target_info['feature_idx']
)
if abs(virtual_weight) > 0.05:
edge_weight = source_activation * virtual_weight
self.graph.add_edge(source_node, target_node, weight=edge_weight)
self.edge_weights[(source_node, target_node)] = edge_weight
edges_added += 1
# Add edges to the output node.
if source_info['layer'] >= n_layers - 2:
output_weight = source_activation * 0.1
self.graph.add_edge(source_node, output_node, weight=output_weight)
self.edge_weights[(source_node, output_node)] = output_weight
# Add edges from embeddings to early features.
for emb_node in [node for node, type_ in self.node_types.items() if type_ == "embedding"]:
token_idx = int(emb_node.split('_')[1])
for feat_node in feature_nodes:
feat_info = active_features[feat_node]
if feat_info['layer'] == 0 and feat_info['token_pos'] == token_idx:
# Direct connection from an embedding to a first-layer feature.
weight = feat_info['activation'] * 0.5
self.graph.add_edge(emb_node, feat_node, weight=weight)
self.edge_weights[(emb_node, feat_node)] = weight
return self.graph
def prune_graph(self, threshold: float = 0.8) -> nx.DiGraph:
# Prunes the graph to keep only the most important nodes.
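        # Note: `threshold` is interpreted as the fraction of nodes to keep (ranked by the
        # sum of absolute incident edge weights), not as a minimum edge-weight cutoff.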
# Calculate node importance based on edge weights.
node_importance = defaultdict(float)
for (source, target), weight in self.edge_weights.items():
node_importance[source] += abs(weight)
node_importance[target] += abs(weight)
# Keep the top nodes by importance.
sorted_nodes = sorted(node_importance.items(), key=lambda x: x[1], reverse=True)
n_keep = int(len(sorted_nodes) * threshold)
important_nodes = set([node for node, _ in sorted_nodes[:n_keep]])
# Always keep the output and embedding nodes.
for node, type_ in self.node_types.items():
if type_ in ["output", "embedding"]:
important_nodes.add(node)
# Create the pruned graph.
pruned_graph = self.graph.subgraph(important_nodes).copy()
return pruned_graph
def visualize_graph(self, graph: nx.DiGraph = None, save_path: str = None):
# Visualizes the attribution graph.
if graph is None:
graph = self.graph
plt.figure(figsize=(12, 8))
# Create a layout for the graph.
pos = nx.spring_layout(graph, k=1, iterations=50)
# Color the nodes by type.
node_colors = []
for node in graph.nodes():
node_type = self.node_types.get(node, "unknown")
if node_type == "embedding":
node_colors.append('lightblue')
elif node_type == "feature":
node_colors.append('lightgreen')
elif node_type == "output":
node_colors.append('orange')
else:
node_colors.append('gray')
# Draw the nodes.
nx.draw_networkx_nodes(graph, pos, node_color=node_colors,
node_size=300, alpha=0.8)
# Draw the edges with thickness based on weight.
edges = graph.edges()
edge_weights = [abs(self.edge_weights.get((u, v), 0.1)) for u, v in edges]
max_weight = max(edge_weights) if edge_weights else 1
edge_widths = [w / max_weight * 3 for w in edge_weights]
nx.draw_networkx_edges(graph, pos, width=edge_widths, alpha=0.6,
edge_color='gray', arrows=True)
# Draw the labels.
nx.draw_networkx_labels(graph, pos, font_size=8)
plt.title("Attribution Graph (German)")
plt.axis('off')
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
plt.show()
class PerturbationExperiments:
# Conducts perturbation experiments to validate hypotheses.
def __init__(self, model, clt: CrossLayerTranscoder, tokenizer):
self.model = model
self.clt = clt
self.tokenizer = tokenizer
def feature_ablation_experiment(self, input_text: str,
target_layer: int, target_feature: int,
intervention_strength: float = 5.0) -> Dict:
# Ablates a feature and measures the effect on the model's output.
try:
# Clear the MPS cache to prevent memory issues.
if torch.backends.mps.is_available():
torch.mps.empty_cache()
# Tokenize the input.
inputs = self.tokenizer(input_text, return_tensors="pt", padding=True,
truncation=True, max_length=512)
# Move inputs to the correct device.
device = next(self.model.parameters()).device
inputs = {k: v.to(device) for k, v in inputs.items()}
# Get the baseline predictions.
with torch.no_grad():
baseline_outputs = self.model(**inputs)
baseline_logits = baseline_outputs.logits[0, -1, :]
baseline_probs = F.softmax(baseline_logits, dim=-1)
baseline_top_tokens = torch.topk(baseline_probs, k=5)
# TODO: Implement the actual feature intervention.
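            # One possible implementation (sketch only, not wired in here): register a
            # forward hook on the MLP of `target_layer`, re-encode its input with
            # self.clt.encoders[target_layer], clamp feature `target_feature` to zero
            # (ablation) or scale it by `intervention_strength` (steering), decode the
            # edited features back to the residual stream, and add the difference to the
            # MLP output before re-running the forward pass. The exact module path
            # (e.g. self.model.model.layers[target_layer].mlp in the Hugging Face OLMo-2
            # implementation) is an assumption and should be verified.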
# Simulate the effect of the intervention.
intervention_effect = {
'baseline_top_tokens': [
(self.tokenizer.decode([idx]), prob.item())
for idx, prob in zip(baseline_top_tokens.indices, baseline_top_tokens.values)
],
'intervention_layer': target_layer,
'intervention_feature': target_feature,
'intervention_strength': intervention_strength,
'effect_magnitude': 0.1,
'probability_change': 0.05
}
return intervention_effect
except Exception as e:
# Handle MPS memory issues.
print(f" Warning: Perturbation experiment failed due to device issue: {e}")
return {
'baseline_top_tokens': [],
'intervention_layer': target_layer,
'intervention_feature': target_feature,
'intervention_strength': intervention_strength,
'effect_magnitude': 0.0,
'probability_change': 0.0,
'error': str(e)
}
class AttributionGraphsPipeline:
# The main pipeline for the attribution graph analysis.
def __init__(self, config: AttributionGraphConfig):
self.config = config
self.device = torch.device(config.device)
# Load the model and tokenizer.
logger.info(f"Loading OLMo2 7B model from {config.model_path}")
self.tokenizer = AutoTokenizer.from_pretrained(config.model_path)
# Configure model loading based on the device.
if "mps" in config.device:
# MPS supports float16 but not device_map.
self.model = AutoModelForCausalLM.from_pretrained(
config.model_path,
torch_dtype=torch.float16,
device_map=None
).to(self.device)
elif "cuda" in config.device:
self.model = AutoModelForCausalLM.from_pretrained(
config.model_path,
torch_dtype=torch.float16,
device_map="auto"
)
else:
# CPU
self.model = AutoModelForCausalLM.from_pretrained(
config.model_path,
torch_dtype=torch.float32,
device_map=None
).to(self.device)
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
# Initialize the CLT.
model_config = self.model.config.to_dict()
self.clt = CrossLayerTranscoder(model_config, config).to(self.device)
# Initialize the other components.
self.feature_visualizer = FeatureVisualizer(self.tokenizer)
self.attribution_graph = AttributionGraph(self.clt, self.tokenizer)
self.perturbation_experiments = PerturbationExperiments(self.model, self.clt, self.tokenizer)
logger.info("Attribution Graphs Pipeline initialized successfully")
def train_clt(self, training_texts: List[str]) -> Dict:
# Trains the Cross-Layer Transcoder.
logger.info("Starting CLT training...")
optimizer = torch.optim.Adam(self.clt.parameters(), lr=self.config.learning_rate)
training_stats = {
'reconstruction_losses': [],
'sparsity_losses': [],
'total_losses': []
}
for step in tqdm(range(self.config.training_steps), desc="Training CLT"):
# Sample a batch of texts.
batch_texts = np.random.choice(training_texts, size=self.config.batch_size)
total_loss = 0.0
total_recon_loss = 0.0
total_sparsity_loss = 0.0
for text in batch_texts:
# Tokenize the text.
inputs = self.tokenizer(text, return_tensors="pt", max_length=self.config.max_seq_length,
truncation=True, padding=True).to(self.device)
# Get the model activations.
with torch.no_grad():
outputs = self.model(**inputs, output_hidden_states=True)
hidden_states = outputs.hidden_states[1:]
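                    # hidden_states[0] is the embedding output; the slice keeps one residual
                    # stream snapshot per transformer layer, matching the CLT's n_layers.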
# Forward pass through the CLT.
feature_activations, reconstructed_outputs = self.clt(hidden_states)
# Compute the reconstruction loss.
recon_loss = 0.0
for i, (target, pred) in enumerate(zip(hidden_states, reconstructed_outputs)):
recon_loss += F.mse_loss(pred, target)
# Compute the sparsity loss.
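                # tanh(sparsity_lambda * f) is a smooth, saturating penalty on the feature
                # activations; it approximates an L0-style count of active features while
                # staying differentiable, which pushes the CLT toward sparse feature usage.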
sparsity_loss = 0.0
for features in feature_activations:
sparsity_loss += torch.mean(torch.tanh(self.config.sparsity_lambda * features))
# Total loss.
loss = (self.config.reconstruction_loss_weight * recon_loss +
self.config.sparsity_lambda * sparsity_loss)
total_loss += loss
total_recon_loss += recon_loss
total_sparsity_loss += sparsity_loss
# Average the losses.
total_loss /= self.config.batch_size
total_recon_loss /= self.config.batch_size
total_sparsity_loss /= self.config.batch_size
# Backward pass.
optimizer.zero_grad()
total_loss.backward()
optimizer.step()
# Log the progress.
training_stats['total_losses'].append(total_loss.item())
training_stats['reconstruction_losses'].append(total_recon_loss.item())
training_stats['sparsity_losses'].append(total_sparsity_loss.item())
if step % 100 == 0:
logger.info(f"Step {step}: Total Loss = {total_loss.item():.4f}, "
f"Recon Loss = {total_recon_loss.item():.4f}, "
f"Sparsity Loss = {total_sparsity_loss.item():.4f}")
logger.info("CLT training completed")
return training_stats
def analyze_prompt(self, prompt: str, target_token_idx: int = -1) -> Dict:
# Performs a complete analysis for a single prompt.
logger.info(f"Analyzing prompt: '{prompt[:50]}...'")
# Tokenize the prompt.
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
input_tokens = self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
# Get the model activations.
with torch.no_grad():
outputs = self.model(**inputs, output_hidden_states=True)
hidden_states = outputs.hidden_states[1:]
# Forward pass through the CLT.
feature_activations, reconstructed_outputs = self.clt(hidden_states)
logger.info(" > Starting feature visualization and interpretation...")
feature_visualizations = {}
for layer_idx, features in enumerate(feature_activations):
logger.info(f" - Processing Layer {layer_idx}...")
layer_viz = {}
# Analyze the top features for this layer.
# features shape: [batch_size, seq_len, n_features]
feature_importance = torch.mean(features, dim=(0, 1))
top_features = torch.topk(feature_importance, k=min(5, feature_importance.size(0))).indices
for feat_idx in top_features:
viz = self.feature_visualizer.visualize_feature(
feat_idx.item(), layer_idx, features[0], input_tokens
)
interpretation = self.feature_visualizer.interpret_feature(
feat_idx.item(), layer_idx, viz, self.config.qwen_api_config
)
viz['interpretation'] = interpretation
layer_viz[f"feature_{feat_idx.item()}"] = viz
feature_visualizations[f"layer_{layer_idx}"] = layer_viz
# Construct the attribution graph.
graph = self.attribution_graph.construct_graph(
input_tokens, feature_activations, target_token_idx
)
# Prune the graph.
pruned_graph = self.attribution_graph.prune_graph(self.config.pruning_threshold)
# Analyze the most important paths.
important_paths = []
if len(pruned_graph.nodes()) > 0:
# Find paths from embeddings to the output.
embedding_nodes = [node for node, type_ in self.attribution_graph.node_types.items()
if type_ == "embedding" and node in pruned_graph]
output_nodes = [node for node, type_ in self.attribution_graph.node_types.items()
if type_ == "output" and node in pruned_graph]
for emb_node in embedding_nodes[:3]:
for out_node in output_nodes:
try:
paths = list(nx.all_simple_paths(pruned_graph, emb_node, out_node, cutoff=5))
for path in paths[:2]:
path_weight = 1.0
for i in range(len(path) - 1):
edge_weight = self.attribution_graph.edge_weights.get(
(path[i], path[i+1]), 0.0
)
path_weight *= abs(edge_weight)
important_paths.append({
'path': path,
'weight': path_weight,
'description': self._describe_path(path)
})
except nx.NetworkXNoPath:
continue
# Sort paths by importance.
important_paths.sort(key=lambda x: x['weight'], reverse=True)
results = {
'prompt': prompt,
'input_tokens': input_tokens,
'feature_visualizations': feature_visualizations,
'full_graph_stats': {
'n_nodes': len(graph.nodes()),
'n_edges': len(graph.edges()),
'node_types': dict(self.attribution_graph.node_types)
},
'pruned_graph_stats': {
'n_nodes': len(pruned_graph.nodes()),
'n_edges': len(pruned_graph.edges())
},
'important_paths': important_paths[:5],
'graph': pruned_graph
}
return results
def _describe_path(self, path: List[str]) -> str:
# Generates a human-readable description of a path.
descriptions = []
for node in path:
if self.attribution_graph.node_types[node] == "embedding":
token = node.split('_')[2]
descriptions.append(f"Token '{token}'")
elif self.attribution_graph.node_types[node] == "feature":
parts = node.split('_')
layer = parts[1][1:]
feature = parts[3][1:]
# Try to get the interpretation.
key = f"L{layer}_F{feature}"
interpretation = self.feature_visualizer.feature_interpretations.get(key, "unknown")
descriptions.append(f"Feature L{layer}F{feature} ({interpretation})")
elif self.attribution_graph.node_types[node] == "output":
descriptions.append("Output")
return " β†’ ".join(descriptions)
def save_results(self, results: Dict, save_path: str):
# Saves the analysis results to a file.
# Convert the graph to a serializable format.
serializable_results = results.copy()
if 'graph' in serializable_results:
graph_data = nx.node_link_data(serializable_results['graph'])
serializable_results['graph'] = graph_data
with open(save_path, 'w', encoding='utf-8') as f:
json.dump(serializable_results, f, indent=2, default=str, ensure_ascii=False)
logger.info(f"Results saved to {save_path}")
def save_clt(self, path: str):
# Saves the trained CLT model.
torch.save(self.clt.state_dict(), path)
logger.info(f"CLT model saved to {path}")
def load_clt(self, path: str):
# Loads a trained CLT model.
self.clt.load_state_dict(torch.load(path, map_location=self.device))
self.clt.to(self.device)
self.clt.eval()
logger.info(f"Loaded CLT model from {path}")
# --- Configuration ---
MAX_SEQ_LEN = 256
N_FEATURES_PER_LAYER = 512
TRAINING_STEPS = 2500
BATCH_SIZE = 64
LEARNING_RATE = 1e-3
# German prompts for the final analysis.
ANALYSIS_PROMPTS = [
"Die Hauptstadt von Frankreich ist",
"def fakultaet(n):",
"Das literarische Stilmittel im Satz 'Der Wind flΓΌsterte durch die BΓ€ume' ist"
]
# A larger set of German prompts for training.
TRAINING_PROMPTS = [
"Die Hauptstadt von Frankreich ist", "Sein oder Nichtsein, das ist hier die Frage", "Was du heute kannst besorgen, das verschiebe nicht auf morgen",
"Der erste Mensch auf dem Mond war", "Die chemische Formel fΓΌr Wasser ist H2O.",
"Übersetze ins Englische: 'Die Katze sitzt auf der Matte.'", "def fakultaet(n):", "import numpy as np",
"Die Hauptzutaten einer Pizza sind", "Was ist das Kraftwerk der Zelle?",
"Die Gleichung E=mc^2 beschreibt die Beziehung zwischen Energie und", "Setze die Geschichte fort: Es war einmal, da war ein",
"Klassifiziere das Sentiment: 'Ich bin ΓΌberglΓΌcklich!'", "Extrahiere die EntitΓ€ten: 'Apple Inc. ist in Cupertino.'",
"Was ist die nΓ€chste Zahl: 2, 4, 8, 16, __?", "Ein rollender Stein setzt kein Moos an",
"Das Gegenteil von heiß ist", "import torch", "import pandas as pd", "class MeineKlasse:",
"def __init__(self):", "Die PrimΓ€rfarben sind", "Was ist die Hauptstadt von Japan?",
"Wer hat 'Hamlet' geschrieben?", "Die Quadratwurzel von 64 ist", "Die Sonne geht im Osten auf",
"Der Pazifische Ozean ist der grâßte Ozean der Erde.", "Die Mitochondrien sind das Kraftwerk der Zelle.",
"Was ist die Hauptstadt der Mongolei?", "Der Film 'Matrix' kann folgendem Genre zugeordnet werden:",
"Die englische Übersetzung von 'Ich mâchte bitte einen Kaffee bestellen.' lautet:",
"Das literarische Stilmittel im Satz 'Der Wind flΓΌsterte durch die BΓ€ume' ist",
"Eine Python-Funktion, die die FakultΓ€t einer Zahl berechnet, lautet:",
"Die Hauptzutat eines Negroni-Cocktails ist",
"Fasse die Handlung von 'Hamlet' in einem Satz zusammen:",
"Der Satz 'Der Kuchen wurde vom Hund gefressen' steht in folgender Form:",
"Eine gute Überschrift für einen Artikel über einen neuen Durchbruch in der Batterietechnologie wÀre:"
]
# --- Qwen API for Feature Interpretation ---
def get_feature_interpretation_with_qwen(
api_config: dict,
top_tokens: list[str],
feature_name: str,
layer_index: int,
max_retries: int = 3,
initial_backoff: float = 2.0
) -> str:
# Generates a high-quality interpretation for a feature using the Qwen API.
if not api_config or not api_config.get('api_key'):
logger.warning("Qwen API not configured. Skipping interpretation.")
return "API not configured"
headers = {
"Authorization": f"Bearer {api_config['api_key']}",
"Content-Type": "application/json"
}
# Create a specialized German prompt.
prompt_text = f"""
    Sie sind ein Experte für die Interpretierbarkeit von Transformern. Ein Merkmal in einem Sprachmodell (Merkmal '{feature_name}' auf Schicht {layer_index}) wird am stärksten durch die folgenden Token aktiviert:
{', '.join(f"'{token}'" for token in top_tokens)}
Was ist, basierend *nur* auf diesen Token, die wahrscheinlichste Funktion oder Rolle dieses Merkmals?
    Ihre Antwort muss ein kurzer, prägnanter Ausdruck sein (z.B. "Erkennen von Eigennamen", "Identifizieren von JSON-Syntax", "Vervollständigen von Listen", "Erkennen negativer Stimmung"). Schreiben Sie keinen ganzen Satz.
"""
data = {
"model": api_config["model"],
"messages": [
{
"role": "user",
"content": [{"type": "text", "text": prompt_text}]
}
],
"max_tokens": 50,
"temperature": 0.1,
"top_p": 0.9,
"seed": 42
}
logger.info(f" > Interpreting {feature_name} (Layer {layer_index})...")
for attempt in range(max_retries):
try:
logger.info(f" - Attempt {attempt + 1}/{max_retries}: Sending request to Qwen API...")
response = requests.post(
f"{api_config['api_endpoint']}/chat/completions",
headers=headers,
json=data,
timeout=60
)
response.raise_for_status()
result = response.json()
interpretation = result["choices"][0]["message"]["content"].strip()
# Remove quotes from the output.
if interpretation.startswith('"') and interpretation.endswith('"'):
interpretation = interpretation[1:-1]
logger.info(f" - Success! Interpretation: '{interpretation}'")
return interpretation
except requests.exceptions.RequestException as e:
logger.warning(f" - Qwen API request failed (Attempt {attempt + 1}/{max_retries}): {e}")
if attempt < max_retries - 1:
backoff_time = initial_backoff * (2 ** attempt)
logger.info(f" - Retrying in {backoff_time:.1f} seconds...")
time.sleep(backoff_time)
else:
logger.error(" - Max retries reached. Failing.")
return f"API Error: {e}"
except (KeyError, IndexError) as e:
logger.error(f" - Failed to parse Qwen API response: {e}")
return "API Error: Invalid response format"
finally:
# Add a delay to respect API rate limits.
time.sleep(2.1)
return "API Error: Max retries exceeded"
def train_transcoder(transcoder, model, tokenizer, training_prompts, device, steps=1000, batch_size=16, optimizer=None):
# Trains the Cross-Layer Transcoder.
transcoder.train()
# Use a progress bar for visual feedback.
progress_bar = tqdm(range(steps), desc="Training CLT")
for step in progress_bar:
# Get a random batch of prompts.
batch_prompts = random.choices(training_prompts, k=batch_size)
# Tokenize the batch.
inputs = tokenizer(
batch_prompts,
return_tensors="pt",
padding=True,
truncation=True,
max_length=MAX_SEQ_LEN
)
inputs = {k: v.to(device) for k, v in inputs.items()}
# Get the model activations.
with torch.no_grad():
outputs = model(**inputs, output_hidden_states=True)
hidden_states = outputs.hidden_states[1:]
# Forward pass through the CLT.
feature_activations, reconstructed_outputs = transcoder(hidden_states)
# Compute the reconstruction loss.
recon_loss = 0.0
for i, (target, pred) in enumerate(zip(hidden_states, reconstructed_outputs)):
recon_loss += F.mse_loss(pred, target)
# Compute the sparsity loss.
sparsity_loss = 0.0
for features in feature_activations:
sparsity_loss += torch.mean(torch.tanh(0.01 * features))
# Total loss.
loss = (0.8 * recon_loss + 0.2 * sparsity_loss)
if optimizer:
optimizer.zero_grad()
loss.backward()
optimizer.step()
progress_bar.set_postfix({
"Recon Loss": f"{recon_loss.item():.4f}",
"Sparsity Loss": f"{sparsity_loss.item():.4f}",
"Total Loss": f"{loss.item():.4f}"
})
def generate_feature_visualizations(transcoder, model, tokenizer, prompt, device, qwen_api_config=None, graph_config: Optional[AttributionGraphConfig] = None):
# Generates feature visualizations and interpretations for a prompt.
# Tokenize the prompt.
inputs = tokenizer(
prompt,
return_tensors="pt",
padding=True,
truncation=True,
max_length=MAX_SEQ_LEN
)
inputs = {k: v.to(device) for k, v in inputs.items()}
# Get the model activations.
with torch.no_grad():
outputs = model(**inputs, output_hidden_states=True)
hidden_states = outputs.hidden_states[1:]
# Forward pass through the CLT.
feature_activations, reconstructed_outputs = transcoder(hidden_states)
    # Visualize the features (reuse a single visualizer so interpretations accumulate).
    feature_visualizer = FeatureVisualizer(tokenizer)
    feature_visualizations = {}
for layer_idx, features in enumerate(feature_activations):
layer_viz = {}
# Analyze the top features for this layer.
# features shape: [batch_size, seq_len, n_features]
feature_importance = torch.mean(features, dim=(0, 1))
top_features = torch.topk(feature_importance, k=min(5, feature_importance.size(0))).indices
for feat_idx in top_features:
            viz = feature_visualizer.visualize_feature(
                feat_idx.item(), layer_idx, features[0], tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
            )
            interpretation = feature_visualizer.interpret_feature(
                feat_idx.item(), layer_idx, viz, qwen_api_config
            )
viz['interpretation'] = interpretation
layer_viz[f"feature_{feat_idx.item()}"] = viz
feature_visualizations[f"layer_{layer_idx}"] = layer_viz
    # Construct the attribution graph.
    if graph_config is None:
        graph_config = AttributionGraphConfig()
    attribution_graph = AttributionGraph(transcoder, tokenizer)
    graph = attribution_graph.construct_graph(
        tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]), feature_activations, -1
    )
    # Prune the graph.
    pruned_graph = attribution_graph.prune_graph(graph_config.pruning_threshold)
# Analyze the most important paths.
important_paths = []
if len(pruned_graph.nodes()) > 0:
# Find paths from embeddings to the output.
embedding_nodes = [node for node, type_ in attribution_graph.node_types.items()
if type_ == "embedding" and node in pruned_graph]
output_nodes = [node for node, type_ in attribution_graph.node_types.items()
if type_ == "output" and node in pruned_graph]
for emb_node in embedding_nodes[:3]:
for out_node in output_nodes:
try:
paths = list(nx.all_simple_paths(pruned_graph, emb_node, out_node, cutoff=5))
for path in paths[:2]:
path_weight = 1.0
for i in range(len(path) - 1):
edge_weight = attribution_graph.edge_weights.get(
(path[i], path[i+1]), 0.0
)
path_weight *= abs(edge_weight)
                        important_paths.append({
                            'path': path,
                            'weight': path_weight,
                            # AttributionGraph has no path-description helper, so fall back to the raw node ids.
                            'description': " → ".join(path)
                        })
except nx.NetworkXNoPath:
continue
# Sort paths by importance.
important_paths.sort(key=lambda x: x['weight'], reverse=True)
return {
"prompt": prompt,
"full_graph_stats": {
"n_nodes": len(graph.nodes()),
"n_edges": len(graph.edges()),
"node_types": dict(attribution_graph.node_types)
},
"pruned_graph_stats": {
"n_nodes": len(pruned_graph.nodes()),
"n_edges": len(pruned_graph.edges())
},
"feature_visualizations": feature_visualizations,
"important_paths": important_paths[:5]
}
def main():
# Main function to run the analysis for a single prompt.
# Set a seed for reproducibility.
set_seed()
# --- Argument Parser ---
parser = argparse.ArgumentParser(description="Run Attribution Graph analysis for a single prompt.")
parser.add_argument(
'--prompt-index',
type=int,
required=True,
help=f"The 0-based index of the prompt to analyze from the ANALYSIS_PROMPTS list (0 to {len(ANALYSIS_PROMPTS) - 1})."
)
parser.add_argument(
'--force-retrain-clt',
action='store_true',
help="Force re-training of the Cross-Layer Transcoder, even if a saved model exists."
)
args = parser.parse_args()
prompt_idx = args.prompt_index
if not (0 <= prompt_idx < len(ANALYSIS_PROMPTS)):
print(f"❌ Error: --prompt-index must be between 0 and {len(ANALYSIS_PROMPTS) - 1}.")
return
# Get the API config from the utility function.
qwen_api_config = init_qwen_api()
# Configuration
config = AttributionGraphConfig(
model_path="./models/OLMo-2-1124-7B",
n_features_per_layer=512,
training_steps=500,
batch_size=4,
max_seq_length=256,
learning_rate=1e-4,
sparsity_lambda=0.01,
qwen_api_config=qwen_api_config
)
print("Attribution Graphs for OLMo2 7B - Single Prompt Pipeline (German)")
print("=" * 50)
print(f"Model path: {config.model_path}")
print(f"Device: {config.device}")
try:
# Initialize the full pipeline.
print("πŸš€ Initializing Attribution Graphs Pipeline...")
pipeline = AttributionGraphsPipeline(config)
print("βœ“ Pipeline initialized successfully")
print()
# Load an existing CLT model or train a new one.
if os.path.exists(CLT_SAVE_PATH) and not args.force_retrain_clt:
print(f"🧠 Loading existing CLT model from {CLT_SAVE_PATH}...")
pipeline.load_clt(CLT_SAVE_PATH)
print("βœ“ CLT model loaded successfully.")
else:
if args.force_retrain_clt and os.path.exists(CLT_SAVE_PATH):
print("πŸƒβ€β™‚οΈ --force-retrain-clt flag is set. Overwriting existing model.")
# Train a new CLT model.
print("πŸ“š Training a new CLT model...")
print(f" Training on {len(TRAINING_PROMPTS)} example texts...")
training_stats = pipeline.train_clt(TRAINING_PROMPTS)
print("βœ“ CLT training completed.")
# Save the new model.
pipeline.save_clt(CLT_SAVE_PATH)
print(f" Saved trained model to {CLT_SAVE_PATH} for future use.")
print()
# Analyze the selected prompt.
prompt_to_analyze = ANALYSIS_PROMPTS[prompt_idx]
print(f"πŸ” Analyzing prompt {prompt_idx + 1}/{len(ANALYSIS_PROMPTS)}: '{prompt_to_analyze}'")
analysis = pipeline.analyze_prompt(prompt_to_analyze, target_token_idx=-1)
# Display the key results.
print(f" βœ“ Tokenized into {len(analysis['input_tokens'])} tokens")
print(f" βœ“ Full graph: {analysis['full_graph_stats']['n_nodes']} nodes, {analysis['full_graph_stats']['n_edges']} edges")
print(f" βœ“ Pruned graph: {analysis['pruned_graph_stats']['n_nodes']} nodes, {analysis['pruned_graph_stats']['n_edges']} edges")
# Show the top features.
print(" πŸ“Š Top active features:")
for layer_name, layer_features in list(analysis['feature_visualizations'].items())[:3]:
print(f" {layer_name}:")
for feat_name, feat_data in list(layer_features.items())[:2]:
print(f" {feat_name}: {feat_data['interpretation']} (max: {feat_data['max_activation']:.3f})")
print()
# Run a perturbation experiment.
print("πŸ§ͺ Running perturbation experiment...")
# (No need to pass training_stats to the experiment)
if analysis['feature_visualizations']:
first_layer_key = next(iter(analysis['feature_visualizations']), None)
if first_layer_key:
layer_idx = int(first_layer_key.split('_')[1])
first_feature_key = next(iter(analysis['feature_visualizations'][first_layer_key]), None)
if first_feature_key:
feature_idx = int(first_feature_key.split('_')[1])
ablation_result = pipeline.perturbation_experiments.feature_ablation_experiment(
prompt_to_analyze, layer_idx, feature_idx, intervention_strength=3.0
)
print(f" Ablated L{layer_idx}F{feature_idx}: Ξ” probability = {ablation_result['probability_change']:.4f}")
print("βœ“ Perturbation experiment completed")
print()
# Generate a visualization for the prompt.
print("πŸ“ˆ Generating visualization...")
if 'graph' in analysis and analysis['pruned_graph_stats']['n_nodes'] > 0:
viz_path = os.path.join(RESULTS_DIR, f"attribution_graph_prompt_de_{prompt_idx + 1}.png")
pipeline.attribution_graph.visualize_graph(analysis['graph'], save_path=viz_path)
print(f" βœ“ Graph visualization saved to {viz_path}")
else:
print(" - Skipping visualization as no graph was generated or it was empty.")
# Save the results in a format for the web app.
save_path = os.path.join(RESULTS_DIR, f"attribution_graphs_results_de_prompt_{prompt_idx + 1}.json")
# Create a JSON file that can be merged with others.
final_results = {
"analyses": {
f"prompt_de_{prompt_idx + 1}": analysis
},
"config": config.__dict__,
"timestamp": str(time.time())
}
# The web page doesn't use the graph object, so remove it.
if 'graph' in final_results['analyses'][f"prompt_de_{prompt_idx + 1}"]:
del final_results['analyses'][f"prompt_de_{prompt_idx + 1}"]['graph']
pipeline.save_results(final_results, save_path)
print(f"πŸ’Ύ Results saved to {save_path}")
print("\nπŸŽ‰ Analysis for this prompt completed successfully!")
except Exception as e:
print(f"❌ Error during execution: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
main()