#!/usr/bin/env python3
# This script generates attribution graphs for the OLMo-2 7B model on German prompts.
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List, Tuple, Optional, Any
import json
import logging
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForCausalLM
from collections import defaultdict
import networkx as nx
from dataclasses import dataclass
from tqdm import tqdm
import pickle
import requests
import time
import random
import os
import argparse

# --- Make the project root importable so `utils` can be found ---
import sys
sys.path.append(str(Path(__file__).resolve().parent.parent))
# -----------------------------------------------------------------
from utils import init_qwen_api, set_seed

# --- Constants ---
# Configuration for the attribution graph generation pipeline.
RESULTS_DIR = "circuit_analysis/results"
CLT_SAVE_PATH = "circuit_analysis/models/clt_model_de.pth"

# Configure logging.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Set the device for training and inference.
if torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
    logger.info("Using MPS (Metal Performance Shaders) for GPU acceleration")
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda")
    logger.info("Using CUDA for GPU acceleration")
else:
    DEVICE = torch.device("cpu")
    logger.info("Using CPU")

@dataclass
class AttributionGraphConfig:
    # Configuration for building the attribution graph. The @dataclass decorator
    # is required: instances are constructed with keyword arguments in main()
    # and serialized via config.__dict__.
    model_path: str = "./models/OLMo-2-1124-7B"
    max_seq_length: int = 512
    n_features_per_layer: int = 512
    sparsity_lambda: float = 0.01
    reconstruction_loss_weight: float = 1.0
    batch_size: int = 8
    learning_rate: float = 1e-4
    training_steps: int = 1000
    device: str = str(DEVICE)
    pruning_threshold: float = 0.8  # Fraction of nodes kept when pruning the graph.
    intervention_strength: float = 5.0  # For perturbation experiments.
    qwen_api_config: Optional[Dict[str, str]] = None

class JumpReLU(nn.Module):
    # The JumpReLU activation function.
    def __init__(self, threshold: float = 0.0):
        super().__init__()
        self.threshold = threshold

    def forward(self, x):
        return F.relu(x - self.threshold)
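
# Behavior sketch (illustrative values): with threshold=0.5, the inputs
# [-1.0, 0.4, 0.9] map to [0.0, 0.0, 0.4]. Note this is the shifted-ReLU
# form relu(x - threshold): sub-threshold activations are zeroed and the
# rest are shifted down by the threshold, pushing feature codes toward
# sparsity.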

class CrossLayerTranscoder(nn.Module):
    # The Cross-Layer Transcoder (CLT) model.
    def __init__(self, model_config: Dict, clt_config: AttributionGraphConfig):
        super().__init__()
        self.config = clt_config
        self.model_config = model_config
        self.n_layers = model_config['num_hidden_layers']
        self.hidden_size = model_config['hidden_size']
        self.n_features = clt_config.n_features_per_layer
        # Encoder weights for each layer.
        self.encoders = nn.ModuleList([
            nn.Linear(self.hidden_size, self.n_features, bias=False)
            for _ in range(self.n_layers)
        ])
        # Decoder weights for cross-layer connections.
        self.decoders = nn.ModuleDict()
        for source_layer in range(self.n_layers):
            for target_layer in range(source_layer, self.n_layers):
                key = f"{source_layer}_to_{target_layer}"
                self.decoders[key] = nn.Linear(self.n_features, self.hidden_size, bias=False)
        # The activation function.
        self.activation = JumpReLU(threshold=0.0)
        # Initialize the weights.
        self._init_weights()

    def _init_weights(self):
        # Initializes the weights with small random values.
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0.0, std=0.01)

    def encode(self, layer_idx: int, residual_activations: torch.Tensor) -> torch.Tensor:
        # Encodes residual stream activations to feature activations.
        return self.activation(self.encoders[layer_idx](residual_activations))

    def decode(self, source_layer: int, target_layer: int, feature_activations: torch.Tensor) -> torch.Tensor:
        # Decodes feature activations into the MLP output space of the target layer.
        key = f"{source_layer}_to_{target_layer}"
        return self.decoders[key](feature_activations)

    def forward(self, residual_activations: List[torch.Tensor]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        # The forward pass of the CLT.
        feature_activations = []
        reconstructed_mlp_outputs = []
        # Encode features for each layer.
        for layer_idx, residual in enumerate(residual_activations):
            features = self.encode(layer_idx, residual)
            feature_activations.append(features)
        # Reconstruct MLP outputs with cross-layer connections.
        for target_layer in range(self.n_layers):
            reconstruction = torch.zeros_like(residual_activations[target_layer])
            # Sum contributions from all source layers up to and including the target.
            for source_layer in range(target_layer + 1):
                decoded = self.decode(source_layer, target_layer, feature_activations[source_layer])
                reconstruction += decoded
            reconstructed_mlp_outputs.append(reconstruction)
        return feature_activations, reconstructed_mlp_outputs
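
# Shape sketch for the forward pass (a toy config, purely illustrative):
#   cfg = AttributionGraphConfig(n_features_per_layer=32)
#   clt = CrossLayerTranscoder({'num_hidden_layers': 2, 'hidden_size': 64}, cfg)
#   feats, recons = clt([torch.randn(1, 8, 64), torch.randn(1, 8, 64)])
#   # feats[i]  has shape [1, 8, 32]: per-layer feature activations.
#   # recons[i] has shape [1, 8, 64]: the sum of decodes from layers 0..i.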

class FeatureVisualizer:
    # A class to visualize and interpret individual features.
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.feature_interpretations = {}

    def visualize_feature(self, feature_idx: int, layer_idx: int,
                          activations: torch.Tensor, input_tokens: List[str],
                          top_k: int = 10) -> Dict:
        # Creates a visualization summary for a single feature.
        feature_acts = activations[:, feature_idx].detach().cpu().numpy()
        # Find the top activating positions.
        top_positions = np.argsort(feature_acts)[-top_k:][::-1]
        visualization = {
            'feature_idx': feature_idx,
            'layer_idx': layer_idx,
            'max_activation': float(feature_acts.max()),
            'mean_activation': float(feature_acts.mean()),
            'sparsity': float((feature_acts > 0.1).mean()),
            'top_activations': []
        }
        for pos in top_positions:
            if pos < len(input_tokens):
                visualization['top_activations'].append({
                    'token': input_tokens[pos],
                    'position': int(pos),
                    'activation': float(feature_acts[pos])
                })
        return visualization

    def interpret_feature(self, feature_idx: int, layer_idx: int,
                          visualization_data: Dict,
                          qwen_api_config: Optional[Dict[str, str]] = None) -> str:
        # Interprets a feature based on its top activating tokens.
        top_tokens = [item['token'] for item in visualization_data['top_activations']]
        # Use the Qwen API if it is configured.
        if qwen_api_config and qwen_api_config.get('api_key'):
            feature_name = f"L{layer_idx}_F{feature_idx}"
            interpretation = get_feature_interpretation_with_qwen(
                qwen_api_config, top_tokens, feature_name, layer_idx
            )
        else:
            # Fall back to a simple heuristic (labels are German, matching the pipeline).
            if len(set(top_tokens)) == 1:
                interpretation = f"Spezifischer Token: '{top_tokens[0]}'"
            elif all(token.isalpha() for token in top_tokens):
                interpretation = "Wort/alphabetische Tokens"
            elif all(token.isdigit() for token in top_tokens):
                interpretation = "Numerische Tokens"
            elif all(token in '.,!?;:' for token in top_tokens):
                interpretation = "Interpunktion"
            else:
                interpretation = "Gemischte/polysemische Merkmale"
        self.feature_interpretations[f"L{layer_idx}_F{feature_idx}"] = interpretation
        return interpretation
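
# Example of the dictionary `visualize_feature` returns (illustrative values):
#   {'feature_idx': 17, 'layer_idx': 3, 'max_activation': 2.41,
#    'mean_activation': 0.08, 'sparsity': 0.12,
#    'top_activations': [{'token': 'Paris', 'position': 5, 'activation': 2.41}, ...]}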

class AttributionGraph:
    # A class to construct and analyze attribution graphs.
    def __init__(self, clt: CrossLayerTranscoder, tokenizer):
        self.clt = clt
        self.tokenizer = tokenizer
        self.graph = nx.DiGraph()
        self.node_types = {}  # Track node types (feature, embedding, error, output).
        self.edge_weights = {}

    def compute_virtual_weights(self, source_layer: int, target_layer: int,
                                source_feature: int, target_feature: int) -> float:
        # Computes the virtual weight between two features.
        if target_layer <= source_layer:
            return 0.0
        # Get the encoder and decoder weights.
        encoder_weight = self.clt.encoders[target_layer].weight[target_feature]
        total_weight = 0.0
        for intermediate_layer in range(source_layer, target_layer):
            decoder_key = f"{source_layer}_to_{intermediate_layer}"
            if decoder_key in self.clt.decoders:
                decoder_weight = self.clt.decoders[decoder_key].weight[:, source_feature]
                # The virtual weight is the inner product.
                virtual_weight = torch.dot(decoder_weight, encoder_weight).item()
                total_weight += virtual_weight
        return total_weight
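
    # In symbols: with W_dec^{s->l} the decoder from source layer s into layer l
    # and W_enc^t the encoder of target layer t, the quantity above is
    #   w(f_s, f_t) = sum_{l=s}^{t-1} < W_dec^{s->l}[:, f_s], W_enc^t[f_t, :] >,
    # i.e. how strongly the writes of source feature f_s into the residual
    # stream align with the read direction of target feature f_t.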

    def construct_graph(self, input_tokens: List[str],
                        feature_activations: List[torch.Tensor],
                        target_token_idx: int = -1) -> nx.DiGraph:
        # Constructs the attribution graph for a prompt.
        self.graph.clear()
        self.node_types.clear()
        self.edge_weights.clear()
        seq_len = len(input_tokens)
        n_layers = len(feature_activations)
        # Add embedding nodes for the input tokens.
        for i, token in enumerate(input_tokens):
            node_id = f"emb_{i}_{token}"
            self.graph.add_node(node_id)
            self.node_types[node_id] = "embedding"
        # Add nodes for the features.
        active_features = {}
        max_features_per_layer = 20
        for layer_idx, features in enumerate(feature_activations):
            # features shape: [batch_size, seq_len, n_features]
            batch_size, seq_len_layer, n_features = features.shape
            # Get the top activating features for this layer.
            layer_activations = features[0].mean(dim=0)
            top_features = torch.topk(layer_activations,
                                      k=min(max_features_per_layer, n_features)).indices
            for token_pos in range(min(seq_len, seq_len_layer)):
                for feat_idx in top_features:
                    activation = features[0, token_pos, feat_idx.item()].item()
                    if activation > 0.05:
                        node_id = f"feat_L{layer_idx}_T{token_pos}_F{feat_idx.item()}"
                        self.graph.add_node(node_id)
                        self.node_types[node_id] = "feature"
                        active_features[node_id] = {
                            'layer': layer_idx,
                            'token_pos': token_pos,
                            'feature_idx': feat_idx.item(),
                            'activation': activation
                        }
        # Add an output node for the target token.
        output_node = f"output_{target_token_idx}"
        self.graph.add_node(output_node)
        self.node_types[output_node] = "output"
        # Add edges based on virtual weights and activations.
        feature_nodes = [node for node, type_ in self.node_types.items() if type_ == "feature"]
        print(f"  Building attribution graph: {len(feature_nodes)} feature nodes, {len(self.graph.nodes())} total nodes")
        # Limit the number of edges to compute.
        max_edges_per_node = 5
        for i, source_node in enumerate(feature_nodes):
            if i % 50 == 0:
                print(f"    Processing node {i+1}/{len(feature_nodes)}")
            edges_added = 0
            source_info = active_features[source_node]
            source_activation = source_info['activation']
            # Add edges to other features.
            for target_node in feature_nodes:
                if source_node == target_node or edges_added >= max_edges_per_node:
                    continue
                target_info = active_features[target_node]
                # Only add edges that go forward in the network.
                if (target_info['layer'] > source_info['layer'] or
                        (target_info['layer'] == source_info['layer'] and
                         target_info['token_pos'] > source_info['token_pos'])):
                    virtual_weight = self.compute_virtual_weights(
                        source_info['layer'], target_info['layer'],
                        source_info['feature_idx'], target_info['feature_idx']
                    )
                    if abs(virtual_weight) > 0.05:
                        edge_weight = source_activation * virtual_weight
                        self.graph.add_edge(source_node, target_node, weight=edge_weight)
                        self.edge_weights[(source_node, target_node)] = edge_weight
                        edges_added += 1
            # Add edges to the output node.
            if source_info['layer'] >= n_layers - 2:
                output_weight = source_activation * 0.1
                self.graph.add_edge(source_node, output_node, weight=output_weight)
                self.edge_weights[(source_node, output_node)] = output_weight
        # Add edges from embeddings to early features.
        for emb_node in [node for node, type_ in self.node_types.items() if type_ == "embedding"]:
            token_idx = int(emb_node.split('_')[1])
            for feat_node in feature_nodes:
                feat_info = active_features[feat_node]
                if feat_info['layer'] == 0 and feat_info['token_pos'] == token_idx:
                    # Direct connection from an embedding to a first-layer feature.
                    weight = feat_info['activation'] * 0.5
                    self.graph.add_edge(emb_node, feat_node, weight=weight)
                    self.edge_weights[(emb_node, feat_node)] = weight
        return self.graph
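
    # Node naming scheme used above: "emb_{pos}_{token}" for embeddings,
    # "feat_L{layer}_T{pos}_F{idx}" for features and "output_{idx}" for the
    # target token. The path descriptions elsewhere in this file parse this
    # exact format, so change both together.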

    def prune_graph(self, threshold: float = 0.8) -> nx.DiGraph:
        # Prunes the graph, keeping only the most important nodes.
        # Calculate node importance from the summed absolute edge weights.
        node_importance = defaultdict(float)
        for (source, target), weight in self.edge_weights.items():
            node_importance[source] += abs(weight)
            node_importance[target] += abs(weight)
        # Keep the top nodes by importance.
        sorted_nodes = sorted(node_importance.items(), key=lambda x: x[1], reverse=True)
        n_keep = int(len(sorted_nodes) * threshold)
        important_nodes = set(node for node, _ in sorted_nodes[:n_keep])
        # Always keep the output and embedding nodes.
        for node, type_ in self.node_types.items():
            if type_ in ["output", "embedding"]:
                important_nodes.add(node)
        # Create the pruned graph.
        pruned_graph = self.graph.subgraph(important_nodes).copy()
        return pruned_graph
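
    # Note on semantics: `threshold` is the *fraction of nodes kept*, not a
    # weight cutoff. threshold=0.8 keeps the top 80% of nodes ranked by
    # summed absolute edge weight (plus all embedding and output nodes).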

    def visualize_graph(self, graph: nx.DiGraph = None, save_path: str = None):
        # Visualizes the attribution graph.
        if graph is None:
            graph = self.graph
        plt.figure(figsize=(12, 8))
        # Create a layout for the graph.
        pos = nx.spring_layout(graph, k=1, iterations=50)
        # Color the nodes by type.
        node_colors = []
        for node in graph.nodes():
            node_type = self.node_types.get(node, "unknown")
            if node_type == "embedding":
                node_colors.append('lightblue')
            elif node_type == "feature":
                node_colors.append('lightgreen')
            elif node_type == "output":
                node_colors.append('orange')
            else:
                node_colors.append('gray')
        # Draw the nodes.
        nx.draw_networkx_nodes(graph, pos, node_color=node_colors,
                               node_size=300, alpha=0.8)
        # Draw the edges with widths proportional to their weights.
        edges = graph.edges()
        edge_weights = [abs(self.edge_weights.get((u, v), 0.1)) for u, v in edges]
        max_weight = max(edge_weights) if edge_weights else 1
        edge_widths = [w / max_weight * 3 for w in edge_weights]
        nx.draw_networkx_edges(graph, pos, width=edge_widths, alpha=0.6,
                               edge_color='gray', arrows=True)
        # Draw the labels.
        nx.draw_networkx_labels(graph, pos, font_size=8)
        plt.title("Attribution Graph (German)")
        plt.axis('off')
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()

class PerturbationExperiments:
    # Conducts perturbation experiments to validate hypotheses.
    def __init__(self, model, clt: CrossLayerTranscoder, tokenizer):
        self.model = model
        self.clt = clt
        self.tokenizer = tokenizer

    def feature_ablation_experiment(self, input_text: str,
                                    target_layer: int, target_feature: int,
                                    intervention_strength: float = 5.0) -> Dict:
        # Ablates a feature and measures the effect on the model's output.
        try:
            # Clear the MPS cache to prevent memory issues.
            if torch.backends.mps.is_available():
                torch.mps.empty_cache()
            # Tokenize the input.
            inputs = self.tokenizer(input_text, return_tensors="pt", padding=True,
                                    truncation=True, max_length=512)
            # Move inputs to the correct device.
            device = next(self.model.parameters()).device
            inputs = {k: v.to(device) for k, v in inputs.items()}
            # Get the baseline predictions.
            with torch.no_grad():
                baseline_outputs = self.model(**inputs)
                baseline_logits = baseline_outputs.logits[0, -1, :]
                baseline_probs = F.softmax(baseline_logits, dim=-1)
                baseline_top_tokens = torch.topk(baseline_probs, k=5)
            # TODO: Implement the actual feature intervention.
            # For now, report fixed placeholder effect sizes alongside the baseline.
            intervention_effect = {
                'baseline_top_tokens': [
                    (self.tokenizer.decode([idx]), prob.item())
                    for idx, prob in zip(baseline_top_tokens.indices, baseline_top_tokens.values)
                ],
                'intervention_layer': target_layer,
                'intervention_feature': target_feature,
                'intervention_strength': intervention_strength,
                'effect_magnitude': 0.1,  # Placeholder value.
                'probability_change': 0.05  # Placeholder value.
            }
            return intervention_effect
        except Exception as e:
            # Handle device issues (e.g. MPS running out of memory) gracefully.
            print(f"    Warning: Perturbation experiment failed due to device issue: {e}")
            return {
                'baseline_top_tokens': [],
                'intervention_layer': target_layer,
                'intervention_feature': target_feature,
                'intervention_strength': intervention_strength,
                'effect_magnitude': 0.0,
                'probability_change': 0.0,
                'error': str(e)
            }
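
# Hedged sketch of how the TODO above could be implemented with a forward
# hook: subtract a scaled feature direction from one layer's MLP output.
# The module path `model.model.layers[idx].mlp` follows the Llama-style
# layout that OLMo-2 uses in `transformers`; treat it as an assumption and
# check `model.named_modules()` before relying on it.
def _sketch_mlp_ablation_hook(model, layer_idx: int,
                              direction: torch.Tensor, strength: float):
    # Normalize the (hypothetical) feature direction so `strength` sets the scale.
    direction = F.normalize(direction, dim=-1)

    def hook(module, inputs, output):
        # Returning a tensor from a forward hook replaces the module's output.
        return output - strength * direction.to(output.device, output.dtype)

    # The caller should keep the handle and call handle.remove() afterwards.
    return model.model.layers[layer_idx].mlp.register_forward_hook(hook)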

class AttributionGraphsPipeline:
    # The main pipeline for the attribution graph analysis.
    def __init__(self, config: AttributionGraphConfig):
        self.config = config
        self.device = torch.device(config.device)
        # Load the model and tokenizer.
        logger.info(f"Loading OLMo2 7B model from {config.model_path}")
        self.tokenizer = AutoTokenizer.from_pretrained(config.model_path)
        # Configure model loading based on the device.
        if "mps" in config.device:
            # MPS supports float16 but not device_map.
            self.model = AutoModelForCausalLM.from_pretrained(
                config.model_path,
                torch_dtype=torch.float16,
                device_map=None
            ).to(self.device)
        elif "cuda" in config.device:
            self.model = AutoModelForCausalLM.from_pretrained(
                config.model_path,
                torch_dtype=torch.float16,
                device_map="auto"
            )
        else:
            # CPU
            self.model = AutoModelForCausalLM.from_pretrained(
                config.model_path,
                torch_dtype=torch.float32,
                device_map=None
            ).to(self.device)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # Initialize the CLT.
        model_config = self.model.config.to_dict()
        self.clt = CrossLayerTranscoder(model_config, config).to(self.device)
        # Initialize the other components.
        self.feature_visualizer = FeatureVisualizer(self.tokenizer)
        self.attribution_graph = AttributionGraph(self.clt, self.tokenizer)
        self.perturbation_experiments = PerturbationExperiments(self.model, self.clt, self.tokenizer)
        logger.info("Attribution Graphs Pipeline initialized successfully")

    def train_clt(self, training_texts: List[str]) -> Dict:
        # Trains the Cross-Layer Transcoder.
        logger.info("Starting CLT training...")
        optimizer = torch.optim.Adam(self.clt.parameters(), lr=self.config.learning_rate)
        training_stats = {
            'reconstruction_losses': [],
            'sparsity_losses': [],
            'total_losses': []
        }
        for step in tqdm(range(self.config.training_steps), desc="Training CLT"):
            # Sample a batch of texts.
            batch_texts = np.random.choice(training_texts, size=self.config.batch_size)
            total_loss = 0.0
            total_recon_loss = 0.0
            total_sparsity_loss = 0.0
            for text in batch_texts:
                # Tokenize the text.
                inputs = self.tokenizer(text, return_tensors="pt", max_length=self.config.max_seq_length,
                                        truncation=True, padding=True).to(self.device)
                # Get the model activations.
                with torch.no_grad():
                    outputs = self.model(**inputs, output_hidden_states=True)
                    # Skip the embedding output at index 0 and cast to float32,
                    # since the base model may run in float16 while the CLT is float32.
                    hidden_states = [h.float() for h in outputs.hidden_states[1:]]
                # Forward pass through the CLT (gradients flow through the CLT only).
                feature_activations, reconstructed_outputs = self.clt(hidden_states)
                # Compute the reconstruction loss.
                recon_loss = 0.0
                for i, (target, pred) in enumerate(zip(hidden_states, reconstructed_outputs)):
                    recon_loss += F.mse_loss(pred, target)
                # Compute the sparsity loss.
                sparsity_loss = 0.0
                for features in feature_activations:
                    sparsity_loss += torch.mean(torch.tanh(self.config.sparsity_lambda * features))
                # Total loss.
                loss = (self.config.reconstruction_loss_weight * recon_loss +
                        self.config.sparsity_lambda * sparsity_loss)
                total_loss += loss
                total_recon_loss += recon_loss
                total_sparsity_loss += sparsity_loss
            # Average the losses over the batch.
            total_loss /= self.config.batch_size
            total_recon_loss /= self.config.batch_size
            total_sparsity_loss /= self.config.batch_size
            # Backward pass.
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
            # Log the progress.
            training_stats['total_losses'].append(total_loss.item())
            training_stats['reconstruction_losses'].append(total_recon_loss.item())
            training_stats['sparsity_losses'].append(total_sparsity_loss.item())
            if step % 100 == 0:
                logger.info(f"Step {step}: Total Loss = {total_loss.item():.4f}, "
                            f"Recon Loss = {total_recon_loss.item():.4f}, "
                            f"Sparsity Loss = {total_sparsity_loss.item():.4f}")
        logger.info("CLT training completed")
        return training_stats
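
    # Objective summary, in the notation of the code above:
    #   L = reconstruction_loss_weight * sum_l MSE(pred_l, h_l)
    #       + sparsity_lambda * sum_l mean(tanh(sparsity_lambda * f_l))
    # where h_l are the residual activations, pred_l the CLT reconstructions
    # and f_l the feature activations. Note that sparsity_lambda appears twice,
    # inside the tanh and as the loss weight; that is how the code is written.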

    def analyze_prompt(self, prompt: str, target_token_idx: int = -1) -> Dict:
        # Performs a complete analysis for a single prompt.
        logger.info(f"Analyzing prompt: '{prompt[:50]}...'")
        # Tokenize the prompt.
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        input_tokens = self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
        # Get the model activations.
        with torch.no_grad():
            outputs = self.model(**inputs, output_hidden_states=True)
            # Skip the embedding output and cast to float32 for the CLT.
            hidden_states = [h.float() for h in outputs.hidden_states[1:]]
            # Forward pass through the CLT.
            feature_activations, reconstructed_outputs = self.clt(hidden_states)
        logger.info("  > Starting feature visualization and interpretation...")
        feature_visualizations = {}
        for layer_idx, features in enumerate(feature_activations):
            logger.info(f"    - Processing Layer {layer_idx}...")
            layer_viz = {}
            # Analyze the top features for this layer.
            # features shape: [batch_size, seq_len, n_features]
            feature_importance = torch.mean(features, dim=(0, 1))
            top_features = torch.topk(feature_importance, k=min(5, feature_importance.size(0))).indices
            for feat_idx in top_features:
                viz = self.feature_visualizer.visualize_feature(
                    feat_idx.item(), layer_idx, features[0], input_tokens
                )
                interpretation = self.feature_visualizer.interpret_feature(
                    feat_idx.item(), layer_idx, viz, self.config.qwen_api_config
                )
                viz['interpretation'] = interpretation
                layer_viz[f"feature_{feat_idx.item()}"] = viz
            feature_visualizations[f"layer_{layer_idx}"] = layer_viz
        # Construct the attribution graph.
        graph = self.attribution_graph.construct_graph(
            input_tokens, feature_activations, target_token_idx
        )
        # Prune the graph.
        pruned_graph = self.attribution_graph.prune_graph(self.config.pruning_threshold)
        # Analyze the most important paths.
        important_paths = []
        if len(pruned_graph.nodes()) > 0:
            # Find paths from embeddings to the output.
            embedding_nodes = [node for node, type_ in self.attribution_graph.node_types.items()
                               if type_ == "embedding" and node in pruned_graph]
            output_nodes = [node for node, type_ in self.attribution_graph.node_types.items()
                            if type_ == "output" and node in pruned_graph]
            for emb_node in embedding_nodes[:3]:
                for out_node in output_nodes:
                    try:
                        paths = list(nx.all_simple_paths(pruned_graph, emb_node, out_node, cutoff=5))
                        for path in paths[:2]:
                            path_weight = 1.0
                            for i in range(len(path) - 1):
                                edge_weight = self.attribution_graph.edge_weights.get(
                                    (path[i], path[i+1]), 0.0
                                )
                                path_weight *= abs(edge_weight)
                            important_paths.append({
                                'path': path,
                                'weight': path_weight,
                                'description': self._describe_path(path)
                            })
                    except nx.NetworkXNoPath:
                        continue
        # Sort paths by importance.
        important_paths.sort(key=lambda x: x['weight'], reverse=True)
        results = {
            'prompt': prompt,
            'input_tokens': input_tokens,
            'feature_visualizations': feature_visualizations,
            'full_graph_stats': {
                'n_nodes': len(graph.nodes()),
                'n_edges': len(graph.edges()),
                'node_types': dict(self.attribution_graph.node_types)
            },
            'pruned_graph_stats': {
                'n_nodes': len(pruned_graph.nodes()),
                'n_edges': len(pruned_graph.edges())
            },
            'important_paths': important_paths[:5],
            'graph': pruned_graph
        }
        return results

    def _describe_path(self, path: List[str]) -> str:
        # Generates a human-readable description of a path.
        descriptions = []
        for node in path:
            if self.attribution_graph.node_types[node] == "embedding":
                token = node.split('_')[2]
                descriptions.append(f"Token '{token}'")
            elif self.attribution_graph.node_types[node] == "feature":
                parts = node.split('_')
                layer = parts[1][1:]
                feature = parts[3][1:]
                # Try to get the interpretation.
                key = f"L{layer}_F{feature}"
                interpretation = self.feature_visualizer.feature_interpretations.get(key, "unknown")
                descriptions.append(f"Feature L{layer}F{feature} ({interpretation})")
            elif self.attribution_graph.node_types[node] == "output":
                descriptions.append("Output")
        return " → ".join(descriptions)

    def save_results(self, results: Dict, save_path: str):
        # Saves the analysis results to a file.
        # Convert the graph to a serializable format.
        serializable_results = results.copy()
        if 'graph' in serializable_results:
            graph_data = nx.node_link_data(serializable_results['graph'])
            serializable_results['graph'] = graph_data
        with open(save_path, 'w', encoding='utf-8') as f:
            json.dump(serializable_results, f, indent=2, default=str, ensure_ascii=False)
        logger.info(f"Results saved to {save_path}")

    def save_clt(self, path: str):
        # Saves the trained CLT model.
        torch.save(self.clt.state_dict(), path)
        logger.info(f"CLT model saved to {path}")

    def load_clt(self, path: str):
        # Loads a trained CLT model.
        self.clt.load_state_dict(torch.load(path, map_location=self.device))
        self.clt.to(self.device)
        self.clt.eval()
        logger.info(f"Loaded CLT model from {path}")
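
# Usage sketch for the pipeline (assumes the local model path from the config exists):
#   cfg = AttributionGraphConfig()
#   pipe = AttributionGraphsPipeline(cfg)
#   pipe.train_clt(TRAINING_PROMPTS)   # or pipe.load_clt(CLT_SAVE_PATH)
#   results = pipe.analyze_prompt("Die Hauptstadt von Frankreich ist")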

# --- Configuration for the standalone helpers below ---
MAX_SEQ_LEN = 256
N_FEATURES_PER_LAYER = 512
TRAINING_STEPS = 2500
BATCH_SIZE = 64
LEARNING_RATE = 1e-3

# German prompts for the final analysis.
ANALYSIS_PROMPTS = [
    "Die Hauptstadt von Frankreich ist",
    "def fakultaet(n):",
    "Das literarische Stilmittel im Satz 'Der Wind flüsterte durch die Bäume' ist"
]

# A larger set of German prompts for training.
TRAINING_PROMPTS = [
    "Die Hauptstadt von Frankreich ist", "Sein oder Nichtsein, das ist hier die Frage", "Was du heute kannst besorgen, das verschiebe nicht auf morgen",
    "Der erste Mensch auf dem Mond war", "Die chemische Formel für Wasser ist H2O.",
    "Übersetze ins Englische: 'Die Katze sitzt auf der Matte.'", "def fakultaet(n):", "import numpy as np",
    "Die Hauptzutaten einer Pizza sind", "Was ist das Kraftwerk der Zelle?",
    "Die Gleichung E=mc^2 beschreibt die Beziehung zwischen Energie und", "Setze die Geschichte fort: Es war einmal, da war ein",
    "Klassifiziere das Sentiment: 'Ich bin überglücklich!'", "Extrahiere die Entitäten: 'Apple Inc. ist in Cupertino.'",
    "Was ist die nächste Zahl: 2, 4, 8, 16, __?", "Ein rollender Stein setzt kein Moos an",
    "Das Gegenteil von heiß ist", "import torch", "import pandas as pd", "class MeineKlasse:",
    "def __init__(self):", "Die Primärfarben sind", "Was ist die Hauptstadt von Japan?",
    "Wer hat 'Hamlet' geschrieben?", "Die Quadratwurzel von 64 ist", "Die Sonne geht im Osten auf",
    "Der Pazifische Ozean ist der größte Ozean der Erde.", "Die Mitochondrien sind das Kraftwerk der Zelle.",
    "Was ist die Hauptstadt der Mongolei?", "Der Film 'Matrix' kann folgendem Genre zugeordnet werden:",
    "Die englische Übersetzung von 'Ich möchte bitte einen Kaffee bestellen.' lautet:",
    "Das literarische Stilmittel im Satz 'Der Wind flüsterte durch die Bäume' ist",
    "Eine Python-Funktion, die die Fakultät einer Zahl berechnet, lautet:",
    "Die Hauptzutat eines Negroni-Cocktails ist",
    "Fasse die Handlung von 'Hamlet' in einem Satz zusammen:",
    "Der Satz 'Der Kuchen wurde vom Hund gefressen' steht in folgender Form:",
    "Eine gute Überschrift für einen Artikel über einen neuen Durchbruch in der Batterietechnologie wäre:"
]

# --- Qwen API for Feature Interpretation ---
def get_feature_interpretation_with_qwen(
    api_config: dict,
    top_tokens: list[str],
    feature_name: str,
    layer_index: int,
    max_retries: int = 3,
    initial_backoff: float = 2.0
) -> str:
    # Generates a high-quality interpretation for a feature using the Qwen API.
    if not api_config or not api_config.get('api_key'):
        logger.warning("Qwen API not configured. Skipping interpretation.")
        return "API not configured"
    headers = {
        "Authorization": f"Bearer {api_config['api_key']}",
        "Content-Type": "application/json"
    }
    # Create a specialized German prompt.
    prompt_text = f"""
Sie sind ein Experte für die Interpretierbarkeit von Transformern. Ein Merkmal in einem Sprachmodell (Merkmal '{feature_name}' auf Schicht {layer_index}) wird am stärksten durch die folgenden Token aktiviert:
{', '.join(f"'{token}'" for token in top_tokens)}
Was ist, basierend *nur* auf diesen Token, die wahrscheinlichste Funktion oder Rolle dieses Merkmals?
Ihre Antwort muss ein kurzer, prägnanter Ausdruck sein (z.B. "Erkennen von Eigennamen", "Identifizieren von JSON-Syntax", "Vervollständigen von Listen", "Erkennen negativer Stimmung"). Schreiben Sie keinen ganzen Satz.
"""
    data = {
        "model": api_config["model"],
        "messages": [
            {
                "role": "user",
                "content": [{"type": "text", "text": prompt_text}]
            }
        ],
        "max_tokens": 50,
        "temperature": 0.1,
        "top_p": 0.9,
        "seed": 42
    }
    logger.info(f"  > Interpreting {feature_name} (Layer {layer_index})...")
    for attempt in range(max_retries):
        try:
            logger.info(f"    - Attempt {attempt + 1}/{max_retries}: Sending request to Qwen API...")
            response = requests.post(
                f"{api_config['api_endpoint']}/chat/completions",
                headers=headers,
                json=data,
                timeout=60
            )
            response.raise_for_status()
            result = response.json()
            interpretation = result["choices"][0]["message"]["content"].strip()
            # Remove surrounding quotes from the output.
            if interpretation.startswith('"') and interpretation.endswith('"'):
                interpretation = interpretation[1:-1]
            logger.info(f"    - Success! Interpretation: '{interpretation}'")
            return interpretation
        except requests.exceptions.RequestException as e:
            logger.warning(f"    - Qwen API request failed (Attempt {attempt + 1}/{max_retries}): {e}")
            if attempt < max_retries - 1:
                backoff_time = initial_backoff * (2 ** attempt)
                logger.info(f"    - Retrying in {backoff_time:.1f} seconds...")
                time.sleep(backoff_time)
            else:
                logger.error("    - Max retries reached. Failing.")
                return f"API Error: {e}"
        except (KeyError, IndexError) as e:
            logger.error(f"    - Failed to parse Qwen API response: {e}")
            return "API Error: Invalid response format"
        finally:
            # Pause between requests to respect API rate limits (runs even on return).
            time.sleep(2.1)
    return "API Error: Max retries exceeded"
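
# Usage sketch (the endpoint and model values here are hypothetical placeholders):
#   cfg = {"api_key": "sk-...", "api_endpoint": "https://example.com/v1", "model": "qwen-plus"}
#   get_feature_interpretation_with_qwen(cfg, ["Paris", "Berlin", "Madrid"], "L3_F17", 3)
# The request body follows the OpenAI-style /chat/completions schema used above.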

def train_transcoder(transcoder, model, tokenizer, training_prompts, device, steps=1000, batch_size=16, optimizer=None):
    # Trains the Cross-Layer Transcoder.
    transcoder.train()
    # Use a progress bar for visual feedback.
    progress_bar = tqdm(range(steps), desc="Training CLT")
    for step in progress_bar:
        # Get a random batch of prompts.
        batch_prompts = random.choices(training_prompts, k=batch_size)
        # Tokenize the batch.
        inputs = tokenizer(
            batch_prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=MAX_SEQ_LEN
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
        # Get the model activations.
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)
            # Skip the embedding output and cast to float32 for the CLT.
            hidden_states = [h.float() for h in outputs.hidden_states[1:]]
        # Forward pass through the CLT.
        feature_activations, reconstructed_outputs = transcoder(hidden_states)
        # Compute the reconstruction loss.
        recon_loss = 0.0
        for i, (target, pred) in enumerate(zip(hidden_states, reconstructed_outputs)):
            recon_loss += F.mse_loss(pred, target)
        # Compute the sparsity loss.
        sparsity_loss = 0.0
        for features in feature_activations:
            sparsity_loss += torch.mean(torch.tanh(0.01 * features))
        # Total loss: a fixed 0.8/0.2 blend of reconstruction and sparsity.
        loss = (0.8 * recon_loss + 0.2 * sparsity_loss)
        # Only step if an optimizer was provided (otherwise this is a dry run).
        if optimizer:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        progress_bar.set_postfix({
            "Recon Loss": f"{recon_loss.item():.4f}",
            "Sparsity Loss": f"{sparsity_loss.item():.4f}",
            "Total Loss": f"{loss.item():.4f}"
        })
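
# Usage sketch (assumes `transcoder`, `model` and `tokenizer` are already built):
#   opt = torch.optim.Adam(transcoder.parameters(), lr=LEARNING_RATE)
#   train_transcoder(transcoder, model, tokenizer, TRAINING_PROMPTS, DEVICE,
#                    steps=TRAINING_STEPS, batch_size=BATCH_SIZE, optimizer=opt)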

def generate_feature_visualizations(transcoder, model, tokenizer, prompt, device, qwen_api_config=None, graph_config: Optional[AttributionGraphConfig] = None):
    # Generates feature visualizations and interpretations for a prompt.
    # Tokenize the prompt.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=MAX_SEQ_LEN
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}
    input_tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    # Get the model activations.
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
        # Skip the embedding output and cast to float32 for the CLT.
        hidden_states = [h.float() for h in outputs.hidden_states[1:]]
        # Forward pass through the CLT.
        feature_activations, reconstructed_outputs = transcoder(hidden_states)
    # Use a single, shared visualizer so the interpretations accumulate in one place.
    visualizer = FeatureVisualizer(tokenizer)
    feature_visualizations = {}
    for layer_idx, features in enumerate(feature_activations):
        layer_viz = {}
        # Analyze the top features for this layer.
        # features shape: [batch_size, seq_len, n_features]
        feature_importance = torch.mean(features, dim=(0, 1))
        top_features = torch.topk(feature_importance, k=min(5, feature_importance.size(0))).indices
        for feat_idx in top_features:
            viz = visualizer.visualize_feature(
                feat_idx.item(), layer_idx, features[0], input_tokens
            )
            interpretation = visualizer.interpret_feature(
                feat_idx.item(), layer_idx, viz, qwen_api_config
            )
            viz['interpretation'] = interpretation
            layer_viz[f"feature_{feat_idx.item()}"] = viz
        feature_visualizations[f"layer_{layer_idx}"] = layer_viz
    # Construct the attribution graph. AttributionGraph takes only the CLT and
    # the tokenizer; the config is used for the pruning threshold below.
    if graph_config is None:
        graph_config = AttributionGraphConfig()
    attribution_graph = AttributionGraph(transcoder, tokenizer)
    graph = attribution_graph.construct_graph(input_tokens, feature_activations, -1)
    # Prune the graph using the configured threshold.
    pruned_graph = attribution_graph.prune_graph(graph_config.pruning_threshold)
    # Analyze the most important paths.
    important_paths = []
    if len(pruned_graph.nodes()) > 0:
        # Find paths from embeddings to the output.
        embedding_nodes = [node for node, type_ in attribution_graph.node_types.items()
                           if type_ == "embedding" and node in pruned_graph]
        output_nodes = [node for node, type_ in attribution_graph.node_types.items()
                        if type_ == "output" and node in pruned_graph]
        for emb_node in embedding_nodes[:3]:
            for out_node in output_nodes:
                try:
                    paths = list(nx.all_simple_paths(pruned_graph, emb_node, out_node, cutoff=5))
                    for path in paths[:2]:
                        path_weight = 1.0
                        for i in range(len(path) - 1):
                            edge_weight = attribution_graph.edge_weights.get(
                                (path[i], path[i+1]), 0.0
                            )
                            path_weight *= abs(edge_weight)
                        important_paths.append({
                            'path': path,
                            'weight': path_weight,
                            # AttributionGraph has no _describe_path; join the node ids instead.
                            'description': " → ".join(path)
                        })
                except nx.NetworkXNoPath:
                    continue
    # Sort paths by importance.
    important_paths.sort(key=lambda x: x['weight'], reverse=True)
    return {
        "prompt": prompt,
        "full_graph_stats": {
            "n_nodes": len(graph.nodes()),
            "n_edges": len(graph.edges()),
            "node_types": dict(attribution_graph.node_types)
        },
        "pruned_graph_stats": {
            "n_nodes": len(pruned_graph.nodes()),
            "n_edges": len(pruned_graph.edges())
        },
        "feature_visualizations": feature_visualizations,
        "important_paths": important_paths[:5]
    }

def main():
    # Main function to run the analysis for a single prompt.
    # Set a seed for reproducibility.
    set_seed()
    # --- Argument Parser ---
    parser = argparse.ArgumentParser(description="Run Attribution Graph analysis for a single prompt.")
    parser.add_argument(
        '--prompt-index',
        type=int,
        required=True,
        help=f"The 0-based index of the prompt to analyze from the ANALYSIS_PROMPTS list (0 to {len(ANALYSIS_PROMPTS) - 1})."
    )
    parser.add_argument(
        '--force-retrain-clt',
        action='store_true',
        help="Force re-training of the Cross-Layer Transcoder, even if a saved model exists."
    )
    args = parser.parse_args()
    prompt_idx = args.prompt_index
    if not (0 <= prompt_idx < len(ANALYSIS_PROMPTS)):
        print(f"❌ Error: --prompt-index must be between 0 and {len(ANALYSIS_PROMPTS) - 1}.")
        return
    # Get the API config from the utility function.
    qwen_api_config = init_qwen_api()
    # Configuration
    config = AttributionGraphConfig(
        model_path="./models/OLMo-2-1124-7B",
        n_features_per_layer=512,
        training_steps=500,
        batch_size=4,
        max_seq_length=256,
        learning_rate=1e-4,
        sparsity_lambda=0.01,
        qwen_api_config=qwen_api_config
    )
    print("Attribution Graphs for OLMo2 7B - Single Prompt Pipeline (German)")
    print("=" * 50)
    print(f"Model path: {config.model_path}")
    print(f"Device: {config.device}")
    try:
        # Initialize the full pipeline.
        print("🚀 Initializing Attribution Graphs Pipeline...")
        pipeline = AttributionGraphsPipeline(config)
        print("✅ Pipeline initialized successfully")
        print()
        # Load an existing CLT model or train a new one.
        if os.path.exists(CLT_SAVE_PATH) and not args.force_retrain_clt:
            print(f"🧠 Loading existing CLT model from {CLT_SAVE_PATH}...")
            pipeline.load_clt(CLT_SAVE_PATH)
            print("✅ CLT model loaded successfully.")
        else:
            if args.force_retrain_clt and os.path.exists(CLT_SAVE_PATH):
                print("🏃‍♂️ --force-retrain-clt flag is set. Overwriting existing model.")
            # Train a new CLT model.
            print("🎓 Training a new CLT model...")
            print(f"   Training on {len(TRAINING_PROMPTS)} example texts...")
            training_stats = pipeline.train_clt(TRAINING_PROMPTS)
            print("✅ CLT training completed.")
            # Save the new model.
            pipeline.save_clt(CLT_SAVE_PATH)
            print(f"   Saved trained model to {CLT_SAVE_PATH} for future use.")
        print()
        # Analyze the selected prompt.
        prompt_to_analyze = ANALYSIS_PROMPTS[prompt_idx]
        print(f"🔍 Analyzing prompt {prompt_idx + 1}/{len(ANALYSIS_PROMPTS)}: '{prompt_to_analyze}'")
        analysis = pipeline.analyze_prompt(prompt_to_analyze, target_token_idx=-1)
        # Display the key results.
        print(f"   → Tokenized into {len(analysis['input_tokens'])} tokens")
        print(f"   → Full graph: {analysis['full_graph_stats']['n_nodes']} nodes, {analysis['full_graph_stats']['n_edges']} edges")
        print(f"   → Pruned graph: {analysis['pruned_graph_stats']['n_nodes']} nodes, {analysis['pruned_graph_stats']['n_edges']} edges")
        # Show the top features.
        print("   📊 Top active features:")
        for layer_name, layer_features in list(analysis['feature_visualizations'].items())[:3]:
            print(f"     {layer_name}:")
            for feat_name, feat_data in list(layer_features.items())[:2]:
                print(f"       {feat_name}: {feat_data['interpretation']} (max: {feat_data['max_activation']:.3f})")
        print()
        # Run a perturbation experiment.
        print("🧪 Running perturbation experiment...")
        if analysis['feature_visualizations']:
            first_layer_key = next(iter(analysis['feature_visualizations']), None)
            if first_layer_key:
                layer_idx = int(first_layer_key.split('_')[1])
                first_feature_key = next(iter(analysis['feature_visualizations'][first_layer_key]), None)
                if first_feature_key:
                    feature_idx = int(first_feature_key.split('_')[1])
                    ablation_result = pipeline.perturbation_experiments.feature_ablation_experiment(
                        prompt_to_analyze, layer_idx, feature_idx, intervention_strength=3.0
                    )
                    print(f"   Ablated L{layer_idx}F{feature_idx}: Δ probability = {ablation_result['probability_change']:.4f}")
        print("✅ Perturbation experiment completed")
        print()
        # Generate a visualization for the prompt.
        print("📊 Generating visualization...")
        if 'graph' in analysis and analysis['pruned_graph_stats']['n_nodes'] > 0:
            viz_path = os.path.join(RESULTS_DIR, f"attribution_graph_prompt_de_{prompt_idx + 1}.png")
            pipeline.attribution_graph.visualize_graph(analysis['graph'], save_path=viz_path)
            print(f"   ✅ Graph visualization saved to {viz_path}")
        else:
            print("   - Skipping visualization as no graph was generated or it was empty.")
        # Save the results in a format for the web app.
        save_path = os.path.join(RESULTS_DIR, f"attribution_graphs_results_de_prompt_{prompt_idx + 1}.json")
        # Create a JSON file that can be merged with others.
        final_results = {
            "analyses": {
                f"prompt_de_{prompt_idx + 1}": analysis
            },
            "config": config.__dict__,
            "timestamp": str(time.time())
        }
        # The web page doesn't use the graph object, so remove it.
        if 'graph' in final_results['analyses'][f"prompt_de_{prompt_idx + 1}"]:
            del final_results['analyses'][f"prompt_de_{prompt_idx + 1}"]['graph']
        pipeline.save_results(final_results, save_path)
        print(f"💾 Results saved to {save_path}")
        print("\n🎉 Analysis for this prompt completed successfully!")
    except Exception as e:
        print(f"❌ Error during execution: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()