#!/usr/bin/env python3
"""Generate attribution graphs for the OLMo2 7B model.

This module defines a Cross-Layer Transcoder (CLT), helpers to visualize and
interpret its features, an attribution-graph builder over active features,
feature-ablation (perturbation) experiments, and the pipeline that wires the
host model, tokenizer and CLT together.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List, Tuple, Optional, Any, Set
import json
import logging
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForCausalLM
from collections import defaultdict
import networkx as nx
from dataclasses import dataclass
from tqdm import tqdm
import pickle
import requests
import time
import random
import copy
import os
import argparse

# --- Add this block to fix the import path ---
import sys
sys.path.append(str(Path(__file__).resolve().parent.parent))
# ---------------------------------------------

from utilities.utils import init_qwen_api, set_seed

# --- Constants ---
RESULTS_DIR = "circuit_analysis/results"
CLT_SAVE_PATH = "circuit_analysis/models/clt_model.pth"

# Configure logging.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Set the device for training (MPS is preferred over CUDA, then CPU).
if torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
    logger.info("Using MPS (Metal Performance Shaders) for GPU acceleration")
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda")
    logger.info("Using CUDA for GPU acceleration")
else:
    DEVICE = torch.device("cpu")
    logger.info("Using CPU")


@dataclass
class AttributionGraphConfig:
    """Configuration for CLT training, attribution-graph construction and ablations."""

    model_path: str = "./models/OLMo-2-1124-7B"
    max_seq_length: int = 512
    n_features_per_layer: int = 512  # Number of features in each CLT layer
    sparsity_lambda: float = 1e-3  # Updated for L1 sparsity
    reconstruction_loss_weight: float = 1.0
    batch_size: int = 8
    learning_rate: float = 1e-4
    training_steps: int = 1000
    device: str = str(DEVICE)
    pruning_threshold: float = 0.8  # For graph pruning
    intervention_strength: float = 5.0  # For perturbation experiments
    qwen_api_config: Optional[Dict[str, str]] = None
    max_ablation_experiments: Optional[int] = None
    ablation_top_k_tokens: int = 5
    ablation_features_per_layer: Optional[int] = 2
    summary_max_layers: Optional[int] = None
    summary_features_per_layer: Optional[int] = 2
    random_baseline_trials: int = 5
    random_baseline_features: int = 1
    random_baseline_seed: int = 1234
    path_ablation_top_k: int = 3
    random_path_baseline_trials: int = 5
    graph_max_features_per_layer: int = 40
    graph_feature_activation_threshold: float = 0.01
    graph_edge_weight_threshold: float = 0.0
    graph_max_edges_per_node: int = 12


class JumpReLU(nn.Module):
    """JumpReLU activation: keep ``x`` unchanged where ``x > threshold``, else 0.

    Unlike a shifted ReLU (``relu(x - threshold)``), values above the threshold
    keep their full magnitude. For ``threshold == 0.0`` (the value used by
    :class:`CrossLayerTranscoder`) both definitions are identical, including
    their gradients.
    """

    def __init__(self, threshold: float = 0.0):
        super().__init__()
        self.threshold = threshold

    def forward(self, x):
        # x * H(x - threshold), with H the Heaviside step.
        return x * (x > self.threshold).to(x.dtype)


class CrossLayerTranscoder(nn.Module):
    """Cross-Layer Transcoder (CLT).

    Layer ``l`` has an encoder ``hidden_size -> n_features_per_layer`` and one
    decoder ``n_features_per_layer -> hidden_size`` for every target layer
    ``t >= l`` (keyed ``"{l}_to_{t}"``). The reconstruction for layer ``t`` is
    the sum of the decoded features of all layers ``<= t``.
    """

    def __init__(self, model_config: Dict, clt_config: AttributionGraphConfig):
        super().__init__()
        self.config = clt_config
        self.model_config = model_config
        self.n_layers = model_config['num_hidden_layers']
        self.hidden_size = model_config['hidden_size']
        self.n_features = clt_config.n_features_per_layer

        # Encoder weights for each layer.
        self.encoders = nn.ModuleList([
            nn.Linear(self.hidden_size, self.n_features, bias=False)
            for _ in range(self.n_layers)
        ])

        # Decoder weights for cross-layer connections.
        self.decoders = nn.ModuleDict()
        for source_layer in range(self.n_layers):
            for target_layer in range(source_layer, self.n_layers):
                key = f"{source_layer}_to_{target_layer}"
                self.decoders[key] = nn.Linear(self.n_features, self.hidden_size, bias=False)

        # The activation function.
        self.activation = JumpReLU(threshold=0.0)

        # Initialize the weights.
        self._init_weights()

    def _init_weights(self):
        """Re-initialize every ``nn.Linear`` weight from N(0, 0.01^2)."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0.0, std=0.01)

    def encode(self, layer_idx: int, residual_activations: torch.Tensor) -> torch.Tensor:
        """Encode residual-stream activations at ``layer_idx`` into feature activations.

        The input is first moved to the encoder's device and dtype. The pipeline
        loads the host model in float16 on MPS/CUDA while the CLT keeps its
        default float32 parameters, and ``nn.Linear`` rejects mixed dtypes.
        """
        encoder = self.encoders[layer_idx]
        residual_activations = residual_activations.to(
            device=encoder.weight.device, dtype=encoder.weight.dtype
        )
        return self.activation(encoder(residual_activations))

    def decode(self, source_layer: int, target_layer: int, feature_activations: torch.Tensor) -> torch.Tensor:
        """Decode ``source_layer`` features into the MLP-output space of ``target_layer``."""
        key = f"{source_layer}_to_{target_layer}"
        return self.decoders[key](feature_activations)

    def forward(self, residual_activations: List[torch.Tensor]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        """Encode every layer, then reconstruct each layer's MLP output.

        Args:
            residual_activations: one tensor per layer; entry ``i`` feeds
                encoder ``i``, so exactly ``n_layers`` entries are expected.

        Returns:
            ``(feature_activations, reconstructed_mlp_outputs)``, one tensor per
            layer each. Reconstructions have the residual tensors' shape and
            the CLT's dtype/device.
        """
        feature_activations = [
            self.encode(layer_idx, residual)
            for layer_idx, residual in enumerate(residual_activations)
        ]

        reconstructed_mlp_outputs = []
        for target_layer in range(self.n_layers):
            # Accumulate in the CLT's dtype/device (see ``encode``) so the
            # in-place sum below never mixes devices or dtypes.
            reference = feature_activations[target_layer]
            reconstruction = torch.zeros_like(
                residual_activations[target_layer],
                dtype=reference.dtype,
                device=reference.device,
            )
            # Sum contributions from this layer and all previous layers.
            for source_layer in range(target_layer + 1):
                reconstruction += self.decode(source_layer, target_layer, feature_activations[source_layer])
            reconstructed_mlp_outputs.append(reconstruction)

        return feature_activations, reconstructed_mlp_outputs


class FeatureVisualizer:
    """Visualize and interpret individual features, optionally caching interpretations as JSON."""

    def __init__(self, tokenizer, cache_dir: Optional[Path] = None):
        self.tokenizer = tokenizer
        self.feature_interpretations: Dict[str, str] = {}
        self.cache_dir = cache_dir
        if self.cache_dir is not None:
            self.cache_dir = Path(self.cache_dir)
            self.cache_dir.mkdir(parents=True, exist_ok=True)
            self._load_cache()

    def _cache_file(self) -> Optional[Path]:
        """Return the cache file path, or ``None`` when caching is disabled."""
        if self.cache_dir is None:
            return None
        return self.cache_dir / "feature_interpretations.json"

    def _load_cache(self):
        """Merge cached interpretations into memory; failures are logged, not raised."""
        cache_file = self._cache_file()
        if cache_file is None or not cache_file.exists():
            return
        try:
            with open(cache_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            if isinstance(data, dict):
                self.feature_interpretations.update({str(k): str(v) for k, v in data.items()})
        except Exception as e:
            logger.warning(f"Failed to load feature interpretation cache: {e}")

    def _save_cache(self):
        """Write all interpretations to the cache file; failures are logged, not raised."""
        cache_file = self._cache_file()
        if cache_file is None:
            return
        try:
            with open(cache_file, 'w', encoding='utf-8') as f:
                json.dump(self.feature_interpretations, f, indent=2)
        except Exception as e:
            logger.warning(f"Failed to save feature interpretation cache: {e}")

    def visualize_feature(self, feature_idx: int, layer_idx: int, activations: torch.Tensor,
                          input_tokens: List[str], top_k: int = 10) -> Dict:
        """Summarize one feature's activations over a sequence.

        Args:
            feature_idx: column of ``activations`` to inspect.
            layer_idx: layer label copied into the result.
            activations: 2-D tensor indexed ``[position, feature]``.
            input_tokens: token strings; positions beyond it are omitted.
            top_k: number of highest-activating positions to report
                (``0`` or negative reports none).

        Returns:
            A dict with max/mean activation, the fraction of positions above
            0.1 (stored under ``'sparsity'``), and the top activations in
            descending order. Statistics are 0.0 for an empty sequence.
        """
        feature_acts = activations[:, feature_idx].detach().cpu().numpy()
        has_values = feature_acts.size > 0

        visualization = {
            'feature_idx': feature_idx,
            'layer_idx': layer_idx,
            'max_activation': float(feature_acts.max()) if has_values else 0.0,
            'mean_activation': float(feature_acts.mean()) if has_values else 0.0,
            'sparsity': float((feature_acts > 0.1).mean()) if has_values else 0.0,
            'top_activations': []
        }

        # Find the top activating positions, highest first. Slicing after the
        # reverse keeps top_k == 0 empty (``[-0:]`` would select everything).
        top_positions = np.argsort(feature_acts)[::-1][:max(top_k, 0)]
        for pos in top_positions:
            if pos < len(input_tokens):
                visualization['top_activations'].append({
                    'token': input_tokens[pos],
                    'position': int(pos),
                    'activation': float(feature_acts[pos])
                })

        return visualization

    def interpret_feature(self, feature_idx: int, layer_idx: int, visualization_data: Dict,
                          qwen_api_config: Optional[Dict[str, str]] = None) -> str:
        """Return a short description of a feature from its top activating tokens.

        Uses ``get_feature_interpretation_with_qwen`` when ``qwen_api_config``
        has an ``api_key``, otherwise a simple token heuristic. Results are
        memoized under ``"L{layer_idx}_F{feature_idx}"`` and saved to the cache
        file when caching is enabled.
        """
        top_tokens = [item['token'] for item in visualization_data['top_activations']]

        cache_key = f"L{layer_idx}_F{feature_idx}"
        if cache_key in self.feature_interpretations:
            return self.feature_interpretations[cache_key]

        # Use the Qwen API if it is configured.
        if qwen_api_config and qwen_api_config.get('api_key'):
            interpretation = get_feature_interpretation_with_qwen(
                qwen_api_config, top_tokens, cache_key, layer_idx
            )
        else:
            # Use a simple heuristic as a fallback.
            punctuation = set('.,!?;:')
            if len(set(top_tokens)) == 1 and top_tokens:
                interpretation = f"Specific token: '{top_tokens[0]}'"
            elif top_tokens and all(token.isalpha() for token in top_tokens):
                interpretation = "Word/alphabetic tokens"
            elif top_tokens and all(token.isdigit() for token in top_tokens):
                interpretation = "Numeric tokens"
            # Every token must be non-empty and consist only of punctuation
            # characters (a substring test would accept '' and reject '!.').
            elif top_tokens and all(token and set(token) <= punctuation for token in top_tokens):
                interpretation = "Punctuation"
            else:
                interpretation = "Mixed/polysemantic feature"

        self.feature_interpretations[cache_key] = interpretation
        self._save_cache()
        return interpretation


class AttributionGraph:
    """Construct, prune and draw attribution graphs over CLT features."""

    def __init__(self, clt: CrossLayerTranscoder, tokenizer, config: AttributionGraphConfig):
        self.clt = clt
        self.tokenizer = tokenizer
        self.config = config
        self.graph = nx.DiGraph()
        self.node_types = {}  # Node id -> "embedding" / "feature" / "output"
        self.edge_weights = {}  # (source, target) -> edge weight
        self.feature_metadata: Dict[str, Dict[str, Any]] = {}

    def compute_virtual_weights(self, source_layer: int, target_layer: int,
                                source_feature: int, target_feature: int) -> float:
        """Return the virtual weight from a source feature to a later-layer target feature.

        Sums, over decoders ``source_layer -> l`` for ``source_layer <= l < target_layer``,
        the inner product of that decoder's column for ``source_feature`` with
        the target layer's encoder row for ``target_feature``. Returns 0.0 when
        ``target_layer <= source_layer``.
        """
        if target_layer <= source_layer:
            return 0.0

        # Pure read of the weights; avoid building autograd graphs.
        with torch.no_grad():
            encoder_weight = self.clt.encoders[target_layer].weight[target_feature]  # [hidden_size]
            total_weight = 0.0
            for intermediate_layer in range(source_layer, target_layer):
                decoder_key = f"{source_layer}_to_{intermediate_layer}"
                if decoder_key in self.clt.decoders:
                    decoder_weight = self.clt.decoders[decoder_key].weight[:, source_feature]  # [hidden_size]
                    # The virtual weight is the inner product.
                    total_weight += torch.dot(decoder_weight, encoder_weight).item()

        return total_weight

    def construct_graph(self, input_tokens: List[str], feature_activations: List[torch.Tensor],
                        target_token_idx: int = -1) -> nx.DiGraph:
        """Build the attribution graph for one prompt (batch item 0).

        Nodes are embeddings (one per input token), active features (the
        per-layer top features by mean activation, at each position where the
        activation exceeds the configured threshold) and one output node.
        Edges are feature -> later feature (activation x virtual weight),
        feature -> output (layer-scaled activation) and embedding -> layer-0
        feature at the same position (0.5 x activation).

        Args:
            input_tokens: token strings of the prompt.
            feature_activations: per-layer tensors of shape
                ``[batch_size, seq_len, n_features]``.
            target_token_idx: only used to name the output node.

        Returns:
            ``self.graph``, rebuilt from scratch.
        """
        self.graph.clear()
        self.node_types.clear()
        self.edge_weights.clear()
        # Metadata is per graph as well; otherwise nodes of earlier prompts linger.
        self.feature_metadata.clear()

        seq_len = len(input_tokens)
        n_layers = len(feature_activations)

        # Add embedding nodes for the input tokens.
        for i, token in enumerate(input_tokens):
            node_id = f"emb_{i}_{token}"
            self.graph.add_node(node_id)
            self.node_types[node_id] = "embedding"

        # Add nodes for the features.
        active_features = {}  # Track which features are significantly active
        max_features_per_layer = self.config.graph_max_features_per_layer or 20  # Limit features per layer to prevent explosion
        activation_threshold = self.config.graph_feature_activation_threshold
        edge_weight_threshold = self.config.graph_edge_weight_threshold
        max_edges_per_node_cfg = self.config.graph_max_edges_per_node or 5

        for layer_idx, features in enumerate(feature_activations):
            # features shape: [batch_size, seq_len, n_features]
            _, seq_len_layer, n_features = features.shape

            # Get the top activating features for this layer.
            layer_activations = features[0].mean(dim=0)  # Average across sequence
            top_features = torch.topk(layer_activations, k=min(max_features_per_layer, n_features)).indices.tolist()
            # One host transfer per layer instead of one device sync per element.
            layer_features = features[0].detach().cpu()

            for token_pos in range(min(seq_len, seq_len_layer)):
                for feat_idx in top_features:
                    activation = layer_features[token_pos, feat_idx].item()
                    if activation > activation_threshold:
                        node_id = f"feat_L{layer_idx}_T{token_pos}_F{feat_idx}"
                        self.graph.add_node(node_id)
                        self.node_types[node_id] = "feature"
                        active_features[node_id] = {
                            'layer': layer_idx,
                            'token_pos': token_pos,
                            'feature_idx': feat_idx,
                            'activation': activation
                        }
                        self.feature_metadata[node_id] = {
                            'layer': layer_idx,
                            'token_position': token_pos,
                            'feature_index': feat_idx,
                            'activation': activation,
                            'input_token': input_tokens[token_pos] if token_pos < len(input_tokens) else None
                        }

        # Add an output node for the target token.
        output_node = f"output_{target_token_idx}"
        self.graph.add_node(output_node)
        self.node_types[output_node] = "output"

        # Add edges based on virtual weights and activations.
        feature_nodes = [node for node, type_ in self.node_types.items() if type_ == "feature"]
        print(f" Building attribution graph: {len(feature_nodes)} feature nodes, {len(self.graph.nodes())} total nodes")

        # Limit the number of edges to compute.
        max_edges_per_node = max(max_edges_per_node_cfg, 1)  # Limit connections per node

        for i, source_node in enumerate(feature_nodes):
            if i % 50 == 0:  # Progress indicator
                print(f" Processing node {i+1}/{len(feature_nodes)}")

            edges_added = 0
            source_info = active_features[source_node]
            source_activation = source_info['activation']

            # Add edges to other features.
            for target_node in feature_nodes:
                if edges_added >= max_edges_per_node:
                    break  # Cap reached; no later target can add an edge.
                if source_node == target_node:
                    continue

                target_info = active_features[target_node]

                # Only add edges that go forward in the network.
                if (target_info['layer'] > source_info['layer'] or
                        (target_info['layer'] == source_info['layer'] and
                         target_info['token_pos'] > source_info['token_pos'])):
                    virtual_weight = self.compute_virtual_weights(
                        source_info['layer'], target_info['layer'],
                        source_info['feature_idx'], target_info['feature_idx']
                    )
                    if abs(virtual_weight) > edge_weight_threshold:
                        edge_weight = source_activation * virtual_weight
                        self.graph.add_edge(source_node, target_node, weight=edge_weight)
                        self.edge_weights[(source_node, target_node)] = edge_weight
                        edges_added += 1

            # Add edges to the output node.
            layer_position = source_info['layer']
            # Allow contributions from all layers, with smaller weights for early layers.
            layer_scale = 0.1 if layer_position >= n_layers - 2 else max(0.05, 0.1 * (layer_position + 1) / n_layers)
            output_weight = source_activation * layer_scale
            if abs(output_weight) > 0:
                self.graph.add_edge(source_node, output_node, weight=output_weight)
                self.edge_weights[(source_node, output_node)] = output_weight

        # Index layer-0 feature nodes by token position (keeping feature_nodes
        # order) so each embedding only visits its own features.
        layer0_features_by_pos: Dict[int, List[str]] = defaultdict(list)
        for feat_node in feature_nodes:
            feat_info = active_features[feat_node]
            if feat_info['layer'] == 0:
                layer0_features_by_pos[feat_info['token_pos']].append(feat_node)

        # Add edges from embeddings to early features.
        for emb_node in [node for node, type_ in self.node_types.items() if type_ == "embedding"]:
            token_idx = int(emb_node.split('_')[1])
            for feat_node in layer0_features_by_pos.get(token_idx, []):
                # Direct connection from an embedding to a first-layer feature.
                weight = active_features[feat_node]['activation'] * 0.5  # Simplified
                self.graph.add_edge(emb_node, feat_node, weight=weight)
                self.edge_weights[(emb_node, feat_node)] = weight

        return self.graph

    def prune_graph(self, threshold: float = 0.8) -> nx.DiGraph:
        """Return a copy of the graph keeping the most important nodes.

        Node importance is the sum of |weight| over incident edges. The top
        ``int(n * threshold)`` nodes (``n`` = nodes with at least one edge)
        are kept, plus every output and embedding node.
        """
        # Calculate node importance based on edge weights.
        node_importance = defaultdict(float)
        for (source, target), weight in self.edge_weights.items():
            node_importance[source] += abs(weight)
            node_importance[target] += abs(weight)

        # Keep the top nodes by importance.
        sorted_nodes = sorted(node_importance.items(), key=lambda x: x[1], reverse=True)
        n_keep = int(len(sorted_nodes) * threshold)
        important_nodes = {node for node, _ in sorted_nodes[:n_keep]}

        # Always keep the output and embedding nodes.
        for node, type_ in self.node_types.items():
            if type_ in ["output", "embedding"]:
                important_nodes.add(node)

        # Create the pruned graph.
        pruned_graph = self.graph.subgraph(important_nodes).copy()
        return pruned_graph

    def visualize_graph(self, graph: Optional[nx.DiGraph] = None, save_path: Optional[str] = None):
        """Draw ``graph`` (default ``self.graph``), optionally save it, show it, then close the figure."""
        if graph is None:
            graph = self.graph

        plt.figure(figsize=(12, 8))

        # Create a layout for the graph.
        pos = nx.spring_layout(graph, k=1, iterations=50)

        # Color the nodes by type.
        node_colors = []
        for node in graph.nodes():
            node_type = self.node_types.get(node, "unknown")
            if node_type == "embedding":
                node_colors.append('lightblue')
            elif node_type == "feature":
                node_colors.append('lightgreen')
            elif node_type == "output":
                node_colors.append('orange')
            else:
                node_colors.append('gray')

        # Draw the nodes.
        nx.draw_networkx_nodes(graph, pos, node_color=node_colors, node_size=300, alpha=0.8)

        # Draw the edges with thickness based on weight (guard against all-zero weights).
        edges = graph.edges()
        edge_weights = [abs(self.edge_weights.get((u, v), 0.1)) for u, v in edges]
        max_weight = max(edge_weights, default=1) or 1
        edge_widths = [w / max_weight * 3 for w in edge_weights]
        nx.draw_networkx_edges(graph, pos, width=edge_widths, alpha=0.6, edge_color='gray', arrows=True)

        # Draw the labels.
        nx.draw_networkx_labels(graph, pos, font_size=8)

        plt.title("Attribution Graph")
        plt.axis('off')

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close()


class PerturbationExperiments:
    """Feature-ablation experiments that validate attribution hypotheses."""

    def __init__(self, model, clt: CrossLayerTranscoder, tokenizer):
        self.model = model
        self.clt = clt
        self.tokenizer = tokenizer
        self._transformer_blocks: Optional[List[nn.Module]] = None

    def _get_transformer_blocks(self) -> List[nn.Module]:
        """Locate (and memoize) the model's ``ModuleList`` of transformer blocks.

        Candidates are ``ModuleList``s whose length equals
        ``config.num_hidden_layers``. Names ending in ``layers``, ``blocks``,
        ``h`` are preferred in that order, then ties break alphabetically.

        Raises:
            ValueError: if the config lacks ``num_hidden_layers`` or no
                candidate list exists.
        """
        if self._transformer_blocks is not None:
            return self._transformer_blocks

        n_layers = getattr(self.model.config, "num_hidden_layers", None)
        if n_layers is None:
            raise ValueError("Model config does not expose num_hidden_layers; cannot resolve transformer blocks.")

        candidate_lists: List[Tuple[str, nn.ModuleList]] = []
        for name, module in self.model.named_modules():
            if isinstance(module, nn.ModuleList) and len(module) == n_layers:
                candidate_lists.append((name, module))

        if not candidate_lists:
            raise ValueError("Unable to locate transformer block ModuleList in model.")

        # Prefer names that look like transformer blocks.
        def _score(name: str) -> Tuple[int, str]:
            preferred_suffixes = ("layers", "blocks", "h")
            for idx, suffix in enumerate(preferred_suffixes):
                if name.endswith(suffix):
                    return (idx, name)
            return (len(preferred_suffixes), name)

        selected_name, selected_list = sorted(candidate_lists, key=lambda item: _score(item[0]))[0]
        self._transformer_blocks = list(selected_list)
        logger.debug(f"Resolved transformer blocks from ModuleList '{selected_name}'.")
        return self._transformer_blocks

    def _format_top_tokens(self, top_tokens: torch.return_types.topk) -> List[Tuple[str, float]]:
        """Convert a ``torch.topk`` result into ``(decoded token, probability)`` pairs."""
        return [
            (self.tokenizer.decode([int(idx)]), prob.item())
            for idx, prob in zip(top_tokens.indices, top_tokens.values)
        ]

    def _prepare_inputs(self, input_text: str, top_k: int) -> Dict[str, Any]:
        """Tokenize ``input_text``, run the unmodified model and encode its hidden states.

        Returns a dict with the tokenized inputs, baseline outputs/logits,
        last-position probabilities and their top-``top_k``, the last position
        index, the per-layer hidden states (embedding output excluded), the CLT
        feature activations and the default target token (baseline argmax).

        Raises:
            ValueError: if tokenization yields more than one sequence.
        """
        if torch.backends.mps.is_available():
            torch.mps.empty_cache()

        device = next(self.model.parameters()).device
        inputs = self.tokenizer(
            input_text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        )
        if inputs["input_ids"].size(0) != 1:
            raise ValueError("Perturbation experiments currently support only batch size 1.")
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            baseline_outputs = self.model(**inputs, output_hidden_states=True, return_dict=True)

        baseline_logits = baseline_outputs.logits[0]
        target_position = baseline_logits.size(0) - 1
        baseline_last_token_logits = baseline_logits[target_position]
        baseline_probs = F.softmax(baseline_last_token_logits, dim=-1)
        baseline_top_tokens = torch.topk(baseline_probs, k=top_k)

        # hidden_states[0] is skipped; the remaining entries feed the CLT layers.
        hidden_states: List[torch.Tensor] = list(baseline_outputs.hidden_states[1:])
        with torch.no_grad():
            feature_activations, _ = self.clt(hidden_states)

        return {
            'inputs': inputs,
            'baseline_outputs': baseline_outputs,
            'baseline_logits': baseline_logits,
            'baseline_last_token_logits': baseline_last_token_logits,
            'baseline_probs': baseline_probs,
            'baseline_top_tokens': baseline_top_tokens,
            'target_position': target_position,
            'hidden_states': hidden_states,
            'feature_activations': feature_activations,
            'default_target_token_id': baseline_top_tokens.indices[0].item()
        }

    def _compute_feature_contributions(
        self,
        feature_activations: List[torch.Tensor],
        feature_set: List[Tuple[int, int]]
    ) -> Dict[int, torch.Tensor]:
        """Sum, per destination layer, the decoded output of every ``(layer, feature)`` pair.

        A feature at layer ``l`` contributes to every destination layer
        ``d >= l`` through decoder ``"{l}_to_{d}"``. Out-of-range layers or
        features are skipped. Each returned tensor is indexed
        ``[batch, seq, hidden]``.
        """
        contributions: Dict[int, torch.Tensor] = {}
        with torch.no_grad():
            for layer_idx, feature_idx in feature_set:
                if layer_idx >= len(feature_activations):
                    continue
                features = feature_activations[layer_idx]
                if feature_idx >= features.size(-1):
                    continue
                feature_values = features[:, :, feature_idx].detach()
                for dest_layer in range(layer_idx, self.clt.n_layers):
                    decoder_key = f"{layer_idx}_to_{dest_layer}"
                    if decoder_key not in self.clt.decoders:
                        continue
                    decoder = self.clt.decoders[decoder_key]
                    weight_column = decoder.weight[:, feature_idx]
                    contrib = torch.einsum('bs,h->bsh', feature_values, weight_column).detach()
                    if dest_layer in contributions:
                        contributions[dest_layer] += contrib
                    else:
                        contributions[dest_layer] = contrib
        return contributions

    def _run_with_hooks(
        self,
        inputs: Dict[str, torch.Tensor],
        contributions: Dict[int, torch.Tensor],
        intervention_strength: float
    ):
        """Run the model with ``intervention_strength * contribution`` subtracted from block outputs.

        A forward hook is registered on each destination block; the hooks are
        always removed afterwards, even if the forward pass raises.
        """
        blocks = self._get_transformer_blocks()
        handles: List[Any] = []

        def _make_hook(cached_contrib: torch.Tensor):
            def hook(module, module_input, module_output):
                if isinstance(module_output, torch.Tensor):
                    target_tensor = module_output
                elif isinstance(module_output, (tuple, list)):
                    target_tensor = module_output[0]
                elif hasattr(module_output, "last_hidden_state"):
                    target_tensor = module_output.last_hidden_state
                else:
                    raise TypeError(
                        f"Unsupported module output type '{type(module_output)}' for perturbation hook."
                    )

                tensor_contrib = cached_contrib.to(target_tensor.device).to(target_tensor.dtype)
                scaled = intervention_strength * tensor_contrib

                if isinstance(module_output, torch.Tensor):
                    return module_output - scaled
                elif isinstance(module_output, tuple):
                    modified = module_output[0] - scaled
                    return (modified,) + tuple(module_output[1:])
                elif isinstance(module_output, list):
                    modified = [module_output[0] - scaled, *module_output[1:]]
                    return modified
                else:
                    module_output.last_hidden_state = module_output.last_hidden_state - scaled
                    return module_output
            return hook

        try:
            for dest_layer, contrib in contributions.items():
                if dest_layer >= len(blocks):
                    continue
                handles.append(blocks[dest_layer].register_forward_hook(_make_hook(contrib)))

            with torch.no_grad():
                outputs = self.model(**inputs, output_hidden_states=True, return_dict=True)
        finally:
            for handle in handles:
                handle.remove()

        return outputs

    def feature_set_ablation_experiment(
        self,
        input_text: str,
        feature_set: List[Tuple[int, int]],
        intervention_strength: float = 5.0,
        target_token_id: Optional[int] = None,
        top_k: int = 5,
        ablation_label: str = "feature_set",
        extra_metadata: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Ablate a set of ``(layer, feature)`` pairs and measure the effect on the last position.

        Reports baseline vs. ablated probability/logit of ``target_token_id``
        (default: the baseline argmax), KL(baseline || ablated), entropy change,
        the change of the final hidden state and whether the top prediction
        flips. ``extra_metadata`` is merged into the result. Any exception is
        logged and reported as a zeroed result carrying an ``'error'`` key.
        """
        try:
            baseline_data = self._prepare_inputs(input_text, top_k)
            if target_token_id is None:
                target_token_id = baseline_data['default_target_token_id']

            feature_set_normalized = [
                (int(layer_idx), int(feature_idx)) for layer_idx, feature_idx in feature_set
            ]
            contributions = self._compute_feature_contributions(
                baseline_data['feature_activations'], feature_set_normalized
            )

            baseline_probs = baseline_data['baseline_probs']
            baseline_top_tokens = baseline_data['baseline_top_tokens']
            baseline_last_token_logits = baseline_data['baseline_last_token_logits']
            target_position = baseline_data['target_position']
            hidden_states = baseline_data['hidden_states']

            baseline_prob = baseline_probs[target_token_id].item()
            baseline_logit = baseline_last_token_logits[target_token_id].item()
            baseline_summary = {
                'baseline_top_tokens': self._format_top_tokens(baseline_top_tokens),
                'baseline_probability': baseline_prob,
                'baseline_logit': baseline_logit
            }

            if not contributions:
                # Nothing to ablate: report the baseline unchanged.
                result = {
                    **baseline_summary,
                    'ablated_top_tokens': baseline_summary['baseline_top_tokens'],
                    'ablated_probability': baseline_prob,
                    'ablated_logit': baseline_logit,
                    'probability_change': 0.0,
                    'logit_change': 0.0,
                    'kl_divergence': 0.0,
                    'entropy_change': 0.0,
                    'hidden_state_delta_norm': 0.0,
                    'hidden_state_relative_change': 0.0,
                    'ablation_flips_top_prediction': False,
                    'feature_set': [
                        {'layer': layer_idx, 'feature': feature_idx}
                        for layer_idx, feature_idx in feature_set_normalized
                    ],
                    'feature_set_size': len(feature_set_normalized),
                    'intervention_strength': intervention_strength,
                    'target_token_id': target_token_id,
                    'target_token': self.tokenizer.decode([target_token_id]),
                    'contributing_layers': [],
                    'ablation_applied': False,
                    'ablation_type': ablation_label,
                    'warning': 'no_contributions_found'
                }
                if extra_metadata:
                    result.update(extra_metadata)
                return result

            ablated_outputs = self._run_with_hooks(
                baseline_data['inputs'],
                contributions,
                intervention_strength
            )

            ablated_logits = ablated_outputs.logits[0, target_position]
            ablated_probs = F.softmax(ablated_logits, dim=-1)
            ablated_top_tokens = torch.topk(ablated_probs, k=top_k)

            ablated_prob = ablated_probs[target_token_id].item()
            ablated_logit = ablated_logits[target_token_id].item()

            epsilon = 1e-9
            kl_divergence = torch.sum(
                baseline_probs * (torch.log(baseline_probs + epsilon) - torch.log(ablated_probs + epsilon))
            ).item()
            if not np.isfinite(kl_divergence):
                kl_divergence = 0.0

            entropy_baseline = -(baseline_probs * torch.log(baseline_probs + epsilon)).sum().item()
            entropy_ablated = -(ablated_probs * torch.log(ablated_probs + epsilon)).sum().item()
            entropy_change = entropy_ablated - entropy_baseline
            if not np.isfinite(entropy_change):
                entropy_change = 0.0

            baseline_hidden = hidden_states[-1][:, target_position, :]
            ablated_hidden = ablated_outputs.hidden_states[-1][:, target_position, :]
            hidden_delta_norm = torch.norm(baseline_hidden - ablated_hidden, dim=-1).item()
            hidden_baseline_norm = torch.norm(baseline_hidden, dim=-1).item()
            hidden_relative_change = hidden_delta_norm / (hidden_baseline_norm + 1e-9)

            result = {
                **baseline_summary,
                'ablated_top_tokens': self._format_top_tokens(ablated_top_tokens),
                'ablated_probability': ablated_prob,
                'ablated_logit': ablated_logit,
                'probability_change': baseline_prob - ablated_prob,
                'logit_change': baseline_logit - ablated_logit,
                'kl_divergence': kl_divergence,
                'entropy_change': entropy_change,
                'hidden_state_delta_norm': hidden_delta_norm,
                'hidden_state_relative_change': hidden_relative_change,
                'ablation_flips_top_prediction': bool(
                    baseline_top_tokens.indices[0].item() != ablated_top_tokens.indices[0].item()
                ),
                'feature_set': [
                    {'layer': layer_idx, 'feature': feature_idx}
                    for layer_idx, feature_idx in feature_set_normalized
                ],
                'feature_set_size': len(feature_set_normalized),
                'intervention_strength': intervention_strength,
                'target_token_id': target_token_id,
                'target_token': self.tokenizer.decode([target_token_id]),
                'contributing_layers': sorted(list(contributions.keys())),
                'ablation_applied': True,
                'ablation_type': ablation_label
            }
            if extra_metadata:
                result.update(extra_metadata)
            return result
        except Exception as e:
            logger.warning(f"Perturbation experiment failed: {e}")
            return {
                'baseline_top_tokens': [],
                'ablated_top_tokens': [],
                'feature_set': [
                    {'layer': layer_idx, 'feature': feature_idx}
                    for layer_idx, feature_idx in feature_set
                ],
                'feature_set_size': len(feature_set),
                'intervention_strength': intervention_strength,
                'probability_change': 0.0,
                'logit_change': 0.0,
                'kl_divergence': 0.0,
                'entropy_change': 0.0,
                'hidden_state_delta_norm': 0.0,
                'hidden_state_relative_change': 0.0,
                'ablation_flips_top_prediction': False,
                'ablation_applied': False,
                'ablation_type': ablation_label,
                'error': str(e)
            }

    def feature_ablation_experiment(
        self,
        input_text: str,
        target_layer: int,
        target_feature: int,
        intervention_strength: float = 5.0,
        target_token_id: Optional[int] = None,
        top_k: int = 5,
    ) -> Dict[str, Any]:
        """Ablate a single feature; see :meth:`feature_set_ablation_experiment`."""
        return self.feature_set_ablation_experiment(
            input_text=input_text,
            feature_set=[(target_layer, target_feature)],
            intervention_strength=intervention_strength,
            target_token_id=target_token_id,
            top_k=top_k,
            ablation_label="targeted_feature"
        )

    def random_feature_ablation_experiment(
        self,
        input_text: str,
        num_features: int = 1,
        intervention_strength: float = 5.0,
        target_token_id: Optional[int] = None,
        top_k: int = 5,
        seed: Optional[int] = None
    ) -> Dict[str, Any]:
        """Ablate ``num_features`` (at least 1) uniformly random ``(layer, feature)`` pairs as a baseline.

        Sampling uses ``random.Random(seed)`` (with replacement); the seed is
        recorded in the result under ``'random_seed'``.
        """
        rng = random.Random(seed)
        num_features = max(1, int(num_features))
        feature_set: List[Tuple[int, int]] = []
        for _ in range(num_features):
            layer_idx = rng.randrange(self.clt.n_layers)
            feature_idx = rng.randrange(self.clt.n_features)
            feature_set.append((layer_idx, feature_idx))

        result = self.feature_set_ablation_experiment(
            input_text=input_text,
            feature_set=feature_set,
            intervention_strength=intervention_strength,
            target_token_id=target_token_id,
            top_k=top_k,
            ablation_label="random_baseline",
            extra_metadata={'random_seed': seed}
        )
        return result


class AttributionGraphsPipeline:
    # The main pipeline for the attribution graph analysis.
    def __init__(self, config: AttributionGraphConfig):
        self.config = config
        self.device = torch.device(config.device)

        # Load the model and tokenizer.
        logger.info(f"Loading OLMo2 7B model from {config.model_path}")
        self.tokenizer = AutoTokenizer.from_pretrained(config.model_path)

        # Configure model loading based on the device.
        if "mps" in config.device:
            # MPS supports float16 but not device_map.
            self.model = AutoModelForCausalLM.from_pretrained(
                config.model_path,
                torch_dtype=torch.float16,
                device_map=None
            ).to(self.device)
        elif "cuda" in config.device:
            self.model = AutoModelForCausalLM.from_pretrained(
                config.model_path,
                torch_dtype=torch.float16,
                device_map="auto"
            )
        else:
            # CPU
            self.model = AutoModelForCausalLM.from_pretrained(
                config.model_path,
                torch_dtype=torch.float32,
                device_map=None
            ).to(self.device)

        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Initialize the CLT.
        model_config = self.model.config.to_dict()
        self.clt = CrossLayerTranscoder(model_config, config).to(self.device)

        # Initialize the other components.
        # cache_dir = Path(RESULTS_DIR) / "feature_interpretations_cache"
        # Disable persistent caching to ensure interpretations are prompt-specific and not reused from other contexts.
        self.feature_visualizer = FeatureVisualizer(self.tokenizer, cache_dir=None)
        self.attribution_graph = AttributionGraph(self.clt, self.tokenizer, config)
        self.perturbation_experiments = PerturbationExperiments(self.model, self.clt, self.tokenizer)

        logger.info("Attribution Graphs Pipeline initialized successfully")

    def train_clt(self, training_texts: List[str]) -> Dict:
        # Trains the Cross-Layer Transcoder.
        logger.info("Starting CLT training...")

        optimizer = torch.optim.Adam(self.clt.parameters(), lr=self.config.learning_rate)

        training_stats = {
            'reconstruction_losses': [],
            'sparsity_losses': [],
            'total_losses': []
        }

        for step in tqdm(range(self.config.training_steps), desc="Training CLT"):
            # Sample a batch of texts.
            batch_texts = np.random.choice(training_texts, size=self.config.batch_size)

            total_loss = 0.0
            total_recon_loss = 0.0
            total_sparsity_loss = 0.0

            for text in batch_texts:
                # Tokenize the text.
                inputs = self.tokenizer(text, return_tensors="pt", max_length=self.config.max_seq_length,
                                        truncation=True, padding=True).to(self.device)

                # Get the model activations.
with torch.no_grad(): outputs = self.model(**inputs, output_hidden_states=True) hidden_states = outputs.hidden_states[1:] # Forward pass through the CLT. feature_activations, reconstructed_outputs = self.clt(hidden_states) # Compute the reconstruction loss. recon_loss = 0.0 for i, (target, pred) in enumerate(zip(hidden_states, reconstructed_outputs)): recon_loss += F.mse_loss(pred, target) # Compute the sparsity loss. sparsity_loss = 0.0 for features in feature_activations: sparsity_loss += torch.mean(torch.tanh(self.config.sparsity_lambda * features)) # Total loss. loss = (self.config.reconstruction_loss_weight * recon_loss + self.config.sparsity_lambda * sparsity_loss) total_loss += loss total_recon_loss += recon_loss total_sparsity_loss += sparsity_loss # Average the losses. total_loss /= self.config.batch_size total_recon_loss /= self.config.batch_size total_sparsity_loss /= self.config.batch_size # Backward pass. optimizer.zero_grad() total_loss.backward() optimizer.step() # Log the progress. training_stats['total_losses'].append(total_loss.item()) training_stats['reconstruction_losses'].append(total_recon_loss.item()) training_stats['sparsity_losses'].append(total_sparsity_loss.item()) if step % 100 == 0: logger.info(f"Step {step}: Total Loss = {total_loss.item():.4f}, " f"Recon Loss = {total_recon_loss.item():.4f}, " f"Sparsity Loss = {total_sparsity_loss.item():.4f}") logger.info("CLT training completed") return training_stats def analyze_prompt(self, prompt: str, target_token_idx: int = -1) -> Dict: # Performs a complete analysis for a single prompt. logger.info(f"Analyzing prompt: '{prompt[:50]}...'") # Tokenize the prompt. inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device) input_tokens = self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) # Get the model activations. with torch.no_grad(): outputs = self.model(**inputs, output_hidden_states=True) hidden_states = outputs.hidden_states[1:] # Forward pass through the CLT. 
feature_activations, reconstructed_outputs = self.clt(hidden_states) logger.info(" > Starting feature visualization and interpretation...") feature_visualizations = {} for layer_idx, features in enumerate(feature_activations): logger.info(f" - Processing Layer {layer_idx}...") layer_viz = {} # Analyze the top features for this layer. # features shape: [batch_size, seq_len, n_features] feature_importance = torch.mean(features, dim=(0, 1)) # Average over batch and sequence top_features = torch.topk(feature_importance, k=min(5, feature_importance.size(0))).indices for feat_idx in top_features: viz = self.feature_visualizer.visualize_feature( feat_idx.item(), layer_idx, features[0], input_tokens ) interpretation = self.feature_visualizer.interpret_feature( feat_idx.item(), layer_idx, viz, self.config.qwen_api_config ) viz['interpretation'] = interpretation layer_viz[f"feature_{feat_idx.item()}"] = viz feature_visualizations[f"layer_{layer_idx}"] = layer_viz # Construct the attribution graph. graph = self.attribution_graph.construct_graph( input_tokens, feature_activations, target_token_idx ) # Prune the graph. pruned_graph = self.attribution_graph.prune_graph(self.config.pruning_threshold) # Analyze the most important paths. important_paths = [] if len(pruned_graph.nodes()) > 0: # Find paths from embeddings to the output. 
embedding_nodes = [node for node, type_ in self.attribution_graph.node_types.items() if type_ == "embedding" and node in pruned_graph] output_nodes = [node for node, type_ in self.attribution_graph.node_types.items() if type_ == "output" and node in pruned_graph] for emb_node in embedding_nodes[:3]: # Top 3 embedding nodes for out_node in output_nodes: try: paths = list(nx.all_simple_paths(pruned_graph, emb_node, out_node, cutoff=5)) for path in paths[:2]: # Top 2 paths path_weight = 1.0 for i in range(len(path) - 1): edge_weight = self.attribution_graph.edge_weights.get( (path[i], path[i+1]), 0.0 ) path_weight *= abs(edge_weight) important_paths.append({ 'path': path, 'weight': path_weight, 'description': self._describe_path(path) }) except nx.NetworkXNoPath: continue # Sort paths by importance. important_paths.sort(key=lambda x: x['weight'], reverse=True) # Run targeted perturbation experiments for highlighted features. targeted_feature_ablation_results: List[Dict[str, Any]] = [] max_total_experiments = self.config.max_ablation_experiments per_layer_limit = self.config.ablation_features_per_layer total_run = 0 stop_all = False for layer_name, layer_features in feature_visualizations.items(): if stop_all: break try: layer_idx = int(layer_name.split('_')[1]) except (IndexError, ValueError): logger.warning(f"Unable to parse layer index from key '{layer_name}'. Skipping perturbation experiments for this layer.") continue feature_items = list(layer_features.items()) if per_layer_limit is not None: feature_items = feature_items[:per_layer_limit] for feature_name, feature_payload in feature_items: if max_total_experiments is not None and total_run >= max_total_experiments: stop_all = True break try: feature_idx = int(feature_name.split('_')[1]) except (IndexError, ValueError): logger.warning(f"Unable to parse feature index from key '{feature_name}'. 
Skipping perturbation experiment.") continue ablation = self.perturbation_experiments.feature_ablation_experiment( prompt, layer_idx, feature_idx, intervention_strength=self.config.intervention_strength, target_token_id=None, top_k=self.config.ablation_top_k_tokens, ) ablation.update({ 'layer_name': layer_name, 'feature_name': feature_name, 'feature_interpretation': feature_payload.get('interpretation'), 'feature_max_activation': feature_payload.get('max_activation'), }) targeted_feature_ablation_results.append(ablation) total_run += 1 # Random baseline perturbations for comparison. random_baseline_results: List[Dict[str, Any]] = [] baseline_trials = self.config.random_baseline_trials if baseline_trials and baseline_trials > 0: num_features = self.config.random_baseline_features or 1 for trial_idx in range(baseline_trials): seed = None if self.config.random_baseline_seed is not None: seed = self.config.random_baseline_seed + trial_idx random_result = self.perturbation_experiments.random_feature_ablation_experiment( prompt, num_features=num_features, intervention_strength=self.config.intervention_strength, target_token_id=None, top_k=self.config.ablation_top_k_tokens, seed=seed ) random_result['trial_index'] = trial_idx random_baseline_results.append(random_result) # Path-level ablations for the most important circuits. 
path_ablation_results: List[Dict[str, Any]] = [] max_paths = self.config.path_ablation_top_k or 0 extracted_paths: List[Dict[str, Any]] = [] if max_paths > 0 and important_paths: for path_info in important_paths[:max_paths]: feature_set = self._extract_feature_set_from_path(path_info.get('path', [])) if not feature_set: continue path_result = self.perturbation_experiments.feature_set_ablation_experiment( prompt, feature_set=feature_set, intervention_strength=self.config.intervention_strength, target_token_id=None, top_k=self.config.ablation_top_k_tokens, ablation_label="path", extra_metadata={ 'path_nodes': path_info.get('path'), 'path_description': path_info.get('description'), 'path_weight': path_info.get('weight') } ) path_ablation_results.append(path_result) enriched_path_info = path_info.copy() enriched_path_info['feature_set'] = feature_set extracted_paths.append(enriched_path_info) random_path_baseline_results: List[Dict[str, Any]] = [] path_baseline_trials = self.config.random_path_baseline_trials if path_baseline_trials and path_baseline_trials > 0 and extracted_paths: rng = random.Random(self.config.random_baseline_seed) available_nodes = [ data for data in self.attribution_graph.node_types.items() if data[1] == "feature" ] for trial in range(path_baseline_trials): selected_path = extracted_paths[min(trial % len(extracted_paths), len(extracted_paths) - 1)] target_length = len(selected_path.get('feature_set', [])) source_layers = [layer for layer, _ in selected_path.get('feature_set', [])] min_layer = min(source_layers) if source_layers else 0 max_layer = max(source_layers) if source_layers else self.clt.n_layers - 1 excluded_keys = { (layer, feature) for layer, feature in selected_path.get('feature_set', []) } random_feature_set: List[Tuple[int, int]] = [] attempts = 0 while len(random_feature_set) < target_length and attempts < target_length * 5: attempts += 1 if not available_nodes: break node_name, node_type = rng.choice(available_nodes) metadata = 
self.attribution_graph.feature_metadata.get(node_name) if metadata is None: continue if metadata['layer'] < min_layer or metadata['layer'] > max_layer: continue key = (metadata['layer'], metadata['feature_index']) if key in excluded_keys: continue if key not in random_feature_set: random_feature_set.append(key) if not random_feature_set: continue if len(random_feature_set) < max(1, target_length): continue random_path_result = self.perturbation_experiments.feature_set_ablation_experiment( prompt, feature_set=random_feature_set, intervention_strength=self.config.intervention_strength, target_token_id=None, top_k=self.config.ablation_top_k_tokens, ablation_label="random_path_baseline", extra_metadata={ 'trial_index': trial, 'sampled_feature_set': random_feature_set, 'reference_path_weight': selected_path.get('weight') } ) random_path_baseline_results.append(random_path_result) targeted_summary = self._summarize_ablation_results(targeted_feature_ablation_results) random_summary = self._summarize_ablation_results(random_baseline_results) path_summary = self._summarize_ablation_results(path_ablation_results) random_path_summary = self._summarize_ablation_results(random_path_baseline_results) summary_statistics = { 'targeted': targeted_summary, 'random_baseline': random_summary, 'path': path_summary, 'random_path_baseline': random_path_summary, 'target_minus_random_abs_probability_change': targeted_summary.get('avg_abs_probability_change', 0.0) - random_summary.get('avg_abs_probability_change', 0.0), 'target_flip_rate_minus_random': targeted_summary.get('flip_rate', 0.0) - random_summary.get('flip_rate', 0.0), 'path_minus_random_abs_probability_change': path_summary.get('avg_abs_probability_change', 0.0) - random_path_summary.get('avg_abs_probability_change', 0.0), 'path_flip_rate_minus_random': path_summary.get('flip_rate', 0.0) - random_path_summary.get('flip_rate', 0.0) } results = { 'prompt': prompt, 'input_tokens': input_tokens, 'feature_visualizations': 
feature_visualizations, 'full_graph_stats': { 'n_nodes': len(graph.nodes()), 'n_edges': len(graph.edges()), 'node_types': dict(self.attribution_graph.node_types) }, 'pruned_graph_stats': { 'n_nodes': len(pruned_graph.nodes()), 'n_edges': len(pruned_graph.edges()) }, 'important_paths': important_paths[:5], # Top 5 paths 'graph': pruned_graph, 'perturbation_experiments': targeted_feature_ablation_results, 'random_baseline_experiments': random_baseline_results, 'path_ablation_experiments': path_ablation_results, 'random_path_baseline_experiments': random_path_baseline_results, 'summary_statistics': summary_statistics } return results def _extract_feature_set_from_path(self, path: List[str]) -> List[Tuple[int, int]]: feature_set: List[Tuple[int, int]] = [] seen: Set[Tuple[int, int]] = set() for node in path: if not isinstance(node, str): continue if not node.startswith("feat_"): continue parts = node.split('_') try: layer_str = parts[1] # e.g., "L0" feature_str = parts[3] # e.g., "F123" layer_idx = int(layer_str[1:]) feature_idx = int(feature_str[1:]) except (IndexError, ValueError): continue key = (layer_idx, feature_idx) if key not in seen: seen.add(key) feature_set.append(key) return feature_set def _summarize_ablation_results(self, experiments: List[Dict[str, Any]]) -> Dict[str, Any]: summary = { 'count': len(experiments), 'avg_probability_change': 0.0, 'avg_abs_probability_change': 0.0, 'std_probability_change': 0.0, 'avg_logit_change': 0.0, 'avg_abs_logit_change': 0.0, 'std_logit_change': 0.0, 'avg_kl_divergence': 0.0, 'avg_entropy_change': 0.0, 'avg_hidden_state_delta_norm': 0.0, 'avg_hidden_state_relative_change': 0.0, 'flip_rate': 0.0, 'count_flipped': 0 } if not experiments: return summary probability_changes = np.array([exp.get('probability_change', 0.0) for exp in experiments], dtype=float) logit_changes = np.array([exp.get('logit_change', 0.0) for exp in experiments], dtype=float) kl_divergences = np.array([exp.get('kl_divergence', 0.0) for exp in 
experiments], dtype=float) entropy_changes = np.array([exp.get('entropy_change', 0.0) for exp in experiments], dtype=float) hidden_norms = np.array([exp.get('hidden_state_delta_norm', 0.0) for exp in experiments], dtype=float) hidden_relative = np.array([exp.get('hidden_state_relative_change', 0.0) for exp in experiments], dtype=float) flip_flags = np.array([1.0 if exp.get('ablation_flips_top_prediction') else 0.0 for exp in experiments], dtype=float) # Helper to safely compute mean/std ignoring NaNs def safe_mean(arr): with np.errstate(all='ignore'): m = np.nanmean(arr) return float(m) if np.isfinite(m) else 0.0 def safe_std(arr): with np.errstate(all='ignore'): s = np.nanstd(arr) return float(s) if np.isfinite(s) else 0.0 summary.update({ 'avg_probability_change': safe_mean(probability_changes), 'avg_abs_probability_change': safe_mean(np.abs(probability_changes)), 'std_probability_change': safe_std(probability_changes), 'avg_logit_change': safe_mean(logit_changes), 'avg_abs_logit_change': safe_mean(np.abs(logit_changes)), 'std_logit_change': safe_std(logit_changes), 'avg_kl_divergence': safe_mean(kl_divergences), 'avg_entropy_change': safe_mean(entropy_changes), 'avg_hidden_state_delta_norm': safe_mean(hidden_norms), 'avg_hidden_state_relative_change': safe_mean(hidden_relative), 'flip_rate': safe_mean(flip_flags), 'count_flipped': int(np.round(np.nansum(flip_flags))) }) return summary def analyze_prompts_batch(self, prompts: List[str]) -> Dict[str, Any]: analyses: Dict[str, Dict[str, Any]] = {} aggregated_targeted: List[Dict[str, Any]] = [] aggregated_random: List[Dict[str, Any]] = [] aggregated_path: List[Dict[str, Any]] = [] for idx, prompt in enumerate(prompts): logger.info(f"[Batch Eval] Processing prompt {idx + 1}/{len(prompts)}") analysis = self.analyze_prompt(prompt) key = f"prompt_{idx + 1}" analyses[key] = analysis aggregated_targeted.extend(analysis.get('perturbation_experiments', [])) 
aggregated_random.extend(analysis.get('random_baseline_experiments', [])) aggregated_path.extend(analysis.get('path_ablation_experiments', [])) aggregate_summary = { 'targeted': self._summarize_ablation_results(aggregated_targeted), 'random_baseline': self._summarize_ablation_results(aggregated_random), 'path': self._summarize_ablation_results(aggregated_path), 'random_path_baseline': self._summarize_ablation_results( [ exp for analysis in analyses.values() for exp in analysis.get('random_path_baseline_experiments', []) ] ) } aggregate_summary['target_minus_random_abs_probability_change'] = ( aggregate_summary['targeted'].get('avg_abs_probability_change', 0.0) - aggregate_summary['random_baseline'].get('avg_abs_probability_change', 0.0) ) aggregate_summary['target_flip_rate_minus_random'] = ( aggregate_summary['targeted'].get('flip_rate', 0.0) - aggregate_summary['random_baseline'].get('flip_rate', 0.0) ) aggregate_summary['path_minus_random_abs_probability_change'] = ( aggregate_summary['path'].get('avg_abs_probability_change', 0.0) - aggregate_summary['random_path_baseline'].get('avg_abs_probability_change', 0.0) ) aggregate_summary['path_flip_rate_minus_random'] = ( aggregate_summary['path'].get('flip_rate', 0.0) - aggregate_summary['random_path_baseline'].get('flip_rate', 0.0) ) return { 'analyses': analyses, 'aggregate_summary': aggregate_summary, 'prompt_texts': prompts } def _describe_path(self, path: List[str]) -> str: # Generates a human-readable description of a path. descriptions = [] for node in path: if self.attribution_graph.node_types[node] == "embedding": token = node.split('_')[2] descriptions.append(f"Token '{token}'") elif self.attribution_graph.node_types[node] == "feature": parts = node.split('_') layer = parts[1][1:] # Remove 'L' feature = parts[3][1:] # Remove 'F' # Try to get the interpretation. 
key = f"L{layer}_F{feature}" interpretation = self.feature_visualizer.feature_interpretations.get(key, "unknown") descriptions.append(f"Feature L{layer}F{feature} ({interpretation})") elif self.attribution_graph.node_types[node] == "output": descriptions.append("Output") return " โ†’ ".join(descriptions) def save_results(self, results: Dict, save_path: str): # Saves the analysis results to a file. serializable_results = copy.deepcopy(results) if 'graph' in serializable_results: serializable_results['graph'] = nx.node_link_data(serializable_results['graph']) analyses = serializable_results.get('analyses', {}) for key, analysis in analyses.items(): if 'graph' in analysis: analysis['graph'] = nx.node_link_data(analysis['graph']) with open(save_path, 'w') as f: json.dump(serializable_results, f, indent=2, default=str) logger.info(f"Results saved to {save_path}") def save_clt(self, path: str): # Saves the trained CLT model. torch.save(self.clt.state_dict(), path) logger.info(f"CLT model saved to {path}") def load_clt(self, path: str): # Loads a trained CLT model. self.clt.load_state_dict(torch.load(path, map_location=self.device)) self.clt.to(self.device) self.clt.eval() # Set the model to evaluation mode logger.info(f"Loaded CLT model from {path}") # --- Configuration --- MAX_SEQ_LEN = 256 N_FEATURES_PER_LAYER = 512 TRAINING_STEPS = 2500 BATCH_SIZE = 64 LEARNING_RATE = 1e-3 # Prompts for generating the final analysis. ANALYSIS_PROMPTS = [ "The capital of France is", "def factorial(n):", "The literary device in the phrase 'The wind whispered through the trees' is" ] # A larger set of prompts for training. 
TRAINING_PROMPTS = [ "The capital of France is", "To be or not to be, that is the", "A stitch in time saves", "The first person to walk on the moon was", "The chemical formula for water is H2O.", "Translate to German: 'The cat sits on the mat.'", "def factorial(n):", "import numpy as np", "The main ingredients in a pizza are", "What is the powerhouse of the cell?", "The equation E=mc^2 relates energy to", "Continue the story: Once upon a time, there was a", "Classify the sentiment: 'I am overjoyed!'", "Extract the entities: 'Apple Inc. is in Cupertino.'", "What is the next number: 2, 4, 8, 16, __?", "A rolling stone gathers no", "The opposite of hot is", "import torch", "import pandas as pd", "class MyClass:", "def __init__(self):", "The primary colors are", "What is the capital of Japan?", "Who wrote 'Hamlet'?", "The square root of 64 is", "The sun rises in the", "The Pacific Ocean is the largest ocean on Earth.", "The mitochondria is the powerhouse of the cell.", "What is the capital of Mongolia?", "The movie 'The Matrix' can be classified into the following genre:", "The French translation of 'I would like to order a coffee, please.' is:", "The literary device in the phrase 'The wind whispered through the trees' is", "A Python function that calculates the factorial of a number is:", "The main ingredient in a Negroni cocktail is", "Summarize the plot of 'Hamlet' in one sentence:", "The sentence 'The cake was eaten by the dog' is in the following voice:", "A good headline for an article about a new breakthrough in battery technology would be:" ] # --- Qwen API for Feature Interpretation --- @torch.no_grad() def get_feature_interpretation_with_qwen( api_config: dict, top_tokens: list[str], feature_name: str, layer_index: int, max_retries: int = 3, initial_backoff: float = 2.0 ) -> str: # Generates a high-quality interpretation for a feature using the Qwen API. if not api_config or not api_config.get('api_key'): logger.warning("Qwen API not configured. 
Skipping interpretation.") return "API not configured" headers = { "Authorization": f"Bearer {api_config['api_key']}", "Content-Type": "application/json" } # Create a specialized prompt. prompt_text = f""" You are an expert in transformer interpretability. A feature in a language model (feature '{feature_name}' at layer {layer_index}) is most strongly activated by the following tokens: {', '.join(f"'{token}'" for token in top_tokens)} Based *only* on these tokens, what is the most likely function or role of this feature? Your answer must be a short, concise phrase (e.g., "Detecting proper nouns", "Identifying JSON syntax", "Completing lists", "Recognizing negative sentiment"). Do not write a full sentence. """ data = { "model": api_config["model"], "messages": [ { "role": "user", "content": [{"type": "text", "text": prompt_text}] } ], "max_tokens": 50, "temperature": 0.1, "top_p": 0.9, "seed": 42 } logger.info(f" > Interpreting {feature_name} (Layer {layer_index})...") for attempt in range(max_retries): try: logger.info(f" - Attempt {attempt + 1}/{max_retries}: Sending request to Qwen API...") response = requests.post( f"{api_config['api_endpoint']}/chat/completions", headers=headers, json=data, timeout=60 ) response.raise_for_status() # Raise an exception for bad status codes (4xx or 5xx) result = response.json() interpretation = result["choices"][0]["message"]["content"].strip() # Remove quotes from the output. if interpretation.startswith('"') and interpretation.endswith('"'): interpretation = interpretation[1:-1] logger.info(f" - Success! Interpretation: '{interpretation}'") return interpretation except requests.exceptions.RequestException as e: logger.warning(f" - Qwen API request failed (Attempt {attempt + 1}/{max_retries}): {e}") if attempt < max_retries - 1: backoff_time = initial_backoff * (2 ** attempt) logger.info(f" - Retrying in {backoff_time:.1f} seconds...") time.sleep(backoff_time) else: logger.error(" - Max retries reached. 
Failing.") return f"API Error: {e}" except (KeyError, IndexError) as e: logger.error(f" - Failed to parse Qwen API response: {e}") return "API Error: Invalid response format" finally: # Add a delay to respect API rate limits. time.sleep(2.1) return "API Error: Max retries exceeded" def train_transcoder(transcoder, model, tokenizer, training_prompts, device, steps=1000, batch_size=16, optimizer=None): # Trains the Cross-Layer Transcoder. transcoder.train() # Use a progress bar for visual feedback. progress_bar = tqdm(range(steps), desc="Training CLT") for step in progress_bar: # Get a random batch of prompts. batch_prompts = random.choices(training_prompts, k=batch_size) # Tokenize the batch. inputs = tokenizer( batch_prompts, return_tensors="pt", padding=True, truncation=True, max_length=MAX_SEQ_LEN ) inputs = {k: v.to(device) for k, v in inputs.items()} # Get the model activations. with torch.no_grad(): outputs = model(**inputs, output_hidden_states=True) hidden_states = outputs.hidden_states[1:] # Forward pass through the CLT. feature_activations, reconstructed_outputs = transcoder(hidden_states) # Compute the reconstruction loss. recon_loss = 0.0 for i, (target, pred) in enumerate(zip(hidden_states, reconstructed_outputs)): recon_loss += F.mse_loss(pred, target) # Compute the sparsity loss. sparsity_loss = 0.0 for features in feature_activations: sparsity_loss += torch.mean(torch.tanh(0.01 * features)) # Use config.sparsity_lambda # Total loss. 
loss = (0.8 * recon_loss + 0.2 * sparsity_loss) # Use config.reconstruction_loss_weight if optimizer: optimizer.zero_grad() loss.backward() optimizer.step() progress_bar.set_postfix({ "Recon Loss": f"{recon_loss.item():.4f}", "Sparsity Loss": f"{sparsity_loss.item():.4f}", "Total Loss": f"{loss.item():.4f}" }) def generate_feature_visualizations(transcoder, model, tokenizer, prompt, device, qwen_api_config=None, graph_config: Optional[AttributionGraphConfig] = None): # Generates feature visualizations and interpretations for a prompt. # Tokenize the prompt. inputs = tokenizer( prompt, return_tensors="pt", padding=True, truncation=True, max_length=MAX_SEQ_LEN ) inputs = {k: v.to(device) for k, v in inputs.items()} # Get the model activations. with torch.no_grad(): outputs = model(**inputs, output_hidden_states=True) hidden_states = outputs.hidden_states[1:] # Forward pass through the CLT. feature_activations, reconstructed_outputs = transcoder(hidden_states) # Visualize the features. feature_visualizations = {} for layer_idx, features in enumerate(feature_activations): layer_viz = {} # Analyze the top features for this layer. # features shape: [batch_size, seq_len, n_features] feature_importance = torch.mean(features, dim=(0, 1)) # Average over batch and sequence top_features = torch.topk(feature_importance, k=min(5, feature_importance.size(0))).indices for feat_idx in top_features: viz = FeatureVisualizer(tokenizer).visualize_feature( feat_idx.item(), layer_idx, features[0], tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) ) interpretation = FeatureVisualizer(tokenizer).interpret_feature( feat_idx.item(), layer_idx, viz, qwen_api_config ) viz['interpretation'] = interpretation layer_viz[f"feature_{feat_idx.item()}"] = viz feature_visualizations[f"layer_{layer_idx}"] = layer_viz # Construct the attribution graph. 
if graph_config is None: graph_config = AttributionGraphConfig() attribution_graph = AttributionGraph(transcoder, tokenizer, graph_config) graph = attribution_graph.construct_graph( tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]), feature_activations, -1 # No target token for visualization ) # Prune the graph. pruned_graph = attribution_graph.prune_graph(0.8) # Use config.pruning_threshold # Analyze the most important paths. important_paths = [] if len(pruned_graph.nodes()) > 0: # Find paths from embeddings to the output. embedding_nodes = [node for node, type_ in attribution_graph.node_types.items() if type_ == "embedding" and node in pruned_graph] output_nodes = [node for node, type_ in attribution_graph.node_types.items() if type_ == "output" and node in pruned_graph] for emb_node in embedding_nodes[:3]: # Top 3 embedding nodes for out_node in output_nodes: try: paths = list(nx.all_simple_paths(pruned_graph, emb_node, out_node, cutoff=5)) for path in paths[:2]: # Top 2 paths path_weight = 1.0 for i in range(len(path) - 1): edge_weight = attribution_graph.edge_weights.get( (path[i], path[i+1]), 0.0 ) path_weight *= abs(edge_weight) important_paths.append({ 'path': path, 'weight': path_weight, 'description': attribution_graph._describe_path(path) }) except nx.NetworkXNoPath: continue # Sort paths by importance. important_paths.sort(key=lambda x: x['weight'], reverse=True) return { "prompt": prompt, "full_graph_stats": { "n_nodes": len(graph.nodes()), "n_edges": len(graph.edges()), "node_types": dict(attribution_graph.node_types) }, "pruned_graph_stats": { "n_nodes": len(pruned_graph.nodes()), "n_edges": len(pruned_graph.edges()) }, "feature_visualizations": feature_visualizations, "important_paths": important_paths[:5] # Top 5 paths } def main(): # Main function to run the analysis for a single prompt. # Set a seed for reproducibility. 
set_seed() # --- Argument Parser --- parser = argparse.ArgumentParser(description="Run Attribution Graph analysis for a single prompt.") parser.add_argument( '--prompt-index', type=int, required=True, help=f"The 0-based index of the prompt to analyze from the ANALYSIS_PROMPTS list (0 to {len(ANALYSIS_PROMPTS) - 1})." ) parser.add_argument( '--force-retrain-clt', action='store_true', help="Force re-training of the Cross-Layer Transcoder, even if a saved model exists." ) parser.add_argument( '--batch-eval', action='store_true', help="Analyze all predefined prompts and compute aggregate faithfulness metrics." ) args = parser.parse_args() prompt_idx = args.prompt_index if not (0 <= prompt_idx < len(ANALYSIS_PROMPTS)): print(f"โŒ Error: --prompt-index must be between 0 and {len(ANALYSIS_PROMPTS) - 1}.") return # Get the API config from the utility function. qwen_api_config = init_qwen_api() # Configuration - Use consistent settings matching trained CLT config = AttributionGraphConfig( model_path="./models/OLMo-2-1124-7B", n_features_per_layer=512, # Match trained CLT training_steps=500, batch_size=4, max_seq_length=256, learning_rate=1e-4, sparsity_lambda=1e-3, # Match training (L1 sparsity) graph_feature_activation_threshold=0.01, graph_edge_weight_threshold=0.003, graph_max_features_per_layer=40, graph_max_edges_per_node=20, qwen_api_config=qwen_api_config ) print("Attribution Graphs for OLMo2 7B - Single Prompt Pipeline") print("=" * 50) print(f"Model path: {config.model_path}") print(f"Device: {config.device}") try: # Initialize the full pipeline. print("๐Ÿš€ Initializing Attribution Graphs Pipeline...") pipeline = AttributionGraphsPipeline(config) print("โœ“ Pipeline initialized successfully") print() # Load an existing CLT model or train a new one. 
if os.path.exists(CLT_SAVE_PATH) and not args.force_retrain_clt: print(f"๐Ÿง  Loading existing CLT model from {CLT_SAVE_PATH}...") pipeline.load_clt(CLT_SAVE_PATH) print("โœ“ CLT model loaded successfully.") else: if args.force_retrain_clt and os.path.exists(CLT_SAVE_PATH): print("๏ฟฝ๏ฟฝโ€โ™‚๏ธ --force-retrain-clt flag is set. Overwriting existing model.") # Train a new CLT model. print("๐Ÿ“š Training a new CLT model...") print(f" Training on {len(TRAINING_PROMPTS)} example texts...") training_stats = pipeline.train_clt(TRAINING_PROMPTS) print("โœ“ CLT training completed.") # Save the training statistics. stats_save_path = os.path.join(RESULTS_DIR, "clt_training_stats.json") with open(stats_save_path, 'w') as f: json.dump(training_stats, f, indent=2) print(f" Saved training stats to {stats_save_path}") # Save the new model. pipeline.save_clt(CLT_SAVE_PATH) print(f" Saved trained model to {CLT_SAVE_PATH} for future use.") print() if args.batch_eval: print("๐Ÿ“Š Running batch faithfulness evaluation across all prompts...") batch_payload = pipeline.analyze_prompts_batch(ANALYSIS_PROMPTS) final_results = copy.deepcopy(batch_payload) final_results['config'] = config.__dict__ final_results['timestamp'] = str(time.time()) for analysis_entry in final_results['analyses'].values(): analysis_entry.pop('graph', None) batch_save_path = os.path.join(RESULTS_DIR, "attribution_graphs_batch_results.json") pipeline.save_results(final_results, batch_save_path) print(f"๐Ÿ’พ Batch results saved to {batch_save_path}") aggregate_summary = batch_payload['aggregate_summary'] targeted_summary = aggregate_summary.get('targeted', {}) random_summary = aggregate_summary.get('random_baseline', {}) path_summary = aggregate_summary.get('path', {}) def _format_summary(label: str, summary: Dict[str, Any]) -> str: return ( f"{label}: count={summary.get('count', 0)}, " f"avg|ฮ”p|={summary.get('avg_abs_probability_change', 0.0):.4f}, " f"flip_rate={summary.get('flip_rate', 0.0):.2%}" ) print("๐Ÿ“ˆ 
Aggregate faithfulness summary") print(f" {_format_summary('Targeted', targeted_summary)}") print(f" {_format_summary('Random baseline', random_summary)}") print(f" {_format_summary('Path', path_summary)}") print(f" {_format_summary('Random path baseline', aggregate_summary.get('random_path_baseline', {}))}") diff_abs = aggregate_summary.get('target_minus_random_abs_probability_change', 0.0) diff_flip = aggregate_summary.get('target_flip_rate_minus_random', 0.0) path_diff_abs = aggregate_summary.get('path_minus_random_abs_probability_change', 0.0) path_diff_flip = aggregate_summary.get('path_flip_rate_minus_random', 0.0) print(f" Targeted vs Random |ฮ”p| difference: {diff_abs:.4f}") print(f" Targeted vs Random flip rate difference: {diff_flip:.4f}") print(f" Path vs Random path |ฮ”p| difference: {path_diff_abs:.4f}") print(f" Path vs Random path flip rate difference: {path_diff_flip:.4f}") print("\n๐ŸŽ‰ Batch evaluation completed successfully!") return # Analyze the selected prompt. prompt_to_analyze = ANALYSIS_PROMPTS[prompt_idx] print(f"๐Ÿ” Analyzing prompt {prompt_idx + 1}/{len(ANALYSIS_PROMPTS)}: '{prompt_to_analyze}'") analysis = pipeline.analyze_prompt(prompt_to_analyze, target_token_idx=-1) # Display the key results. print(f" โœ“ Tokenized into {len(analysis['input_tokens'])} tokens") print(f" โœ“ Full graph: {analysis['full_graph_stats']['n_nodes']} nodes, {analysis['full_graph_stats']['n_edges']} edges") print(f" โœ“ Pruned graph: {analysis['pruned_graph_stats']['n_nodes']} nodes, {analysis['pruned_graph_stats']['n_edges']} edges") # Show the top features. 
print(" ๐Ÿ“Š Top active features:") feature_layers_items = list(analysis['feature_visualizations'].items()) if config.summary_max_layers is not None: feature_layers_items = feature_layers_items[:config.summary_max_layers] for layer_name, layer_features in feature_layers_items: print(f" {layer_name}:") feature_items = layer_features.items() if config.summary_features_per_layer is not None: feature_items = list(feature_items)[:config.summary_features_per_layer] for feat_name, feat_data in feature_items: print(f" {feat_name}: {feat_data['interpretation']} (max: {feat_data['max_activation']:.3f})") print() # Summarize perturbation experiments and baselines. print("๐Ÿงช Targeted feature ablations:") targeted_results = analysis.get('perturbation_experiments', []) if targeted_results: for experiment in targeted_results: layer_name = experiment.get('layer_name', f"L{experiment.get('feature_set', [{}])[0].get('layer', '?')}") feature_name = experiment.get('feature_name', f"F{experiment.get('feature_set', [{}])[0].get('feature', '?')}") prob_delta = experiment.get('probability_change', 0.0) logit_delta = experiment.get('logit_change', 0.0) flips = experiment.get('ablation_flips_top_prediction', False) print(f" {layer_name}/{feature_name}: ฮ”p={prob_delta:.4f}, ฮ”logit={logit_delta:.4f}, flips_top={flips}") else: print(" - No targeted ablations were recorded.") print("\n๐ŸŽฒ Random baseline ablations:") random_baseline = analysis.get('random_baseline_experiments', []) if random_baseline: for experiment in random_baseline: prob_delta = experiment.get('probability_change', 0.0) logit_delta = experiment.get('logit_change', 0.0) flips = experiment.get('ablation_flips_top_prediction', False) trial_idx = experiment.get('trial_index', '?') print(f" Trial {trial_idx}: ฮ”p={prob_delta:.4f}, ฮ”logit={logit_delta:.4f}, flips_top={flips}") else: print(" - No random baseline trials were run.") print("\n๐Ÿ›ค๏ธ Path ablations:") path_results = analysis.get('path_ablation_experiments', []) 
if path_results: for path_exp in path_results: description = path_exp.get('path_description', 'Path') prob_delta = path_exp.get('probability_change', 0.0) logit_delta = path_exp.get('logit_change', 0.0) flips = path_exp.get('ablation_flips_top_prediction', False) print(f" {description}: ฮ”p={prob_delta:.4f}, ฮ”logit={logit_delta:.4f}, flips_top={flips}") else: print(" - No path ablations were run.") summary_stats = analysis.get('summary_statistics', {}) targeted_summary = summary_stats.get('targeted', {}) random_summary = summary_stats.get('random_baseline', {}) path_summary = summary_stats.get('path', {}) random_path_summary = summary_stats.get('random_path_baseline', {}) print("\n๐Ÿ“ˆ Summary statistics:") print(f" Targeted: avg|ฮ”p|={targeted_summary.get('avg_abs_probability_change', 0.0):.4f}, flip_rate={targeted_summary.get('flip_rate', 0.0):.2%}") print(f" Random baseline: avg|ฮ”p|={random_summary.get('avg_abs_probability_change', 0.0):.4f}, flip_rate={random_summary.get('flip_rate', 0.0):.2%}") print(f" Path: avg|ฮ”p|={path_summary.get('avg_abs_probability_change', 0.0):.4f}, flip_rate={path_summary.get('flip_rate', 0.0):.2%}") print(f" Random path baseline: avg|ฮ”p|={random_path_summary.get('avg_abs_probability_change', 0.0):.4f}, flip_rate={random_path_summary.get('flip_rate', 0.0):.2%}") print(f" Targeted vs Random |ฮ”p| diff: {summary_stats.get('target_minus_random_abs_probability_change', 0.0):.4f}") print(f" Targeted vs Random flip diff: {summary_stats.get('target_flip_rate_minus_random', 0.0):.4f}") print(f" Path vs Random path |ฮ”p| diff: {summary_stats.get('path_minus_random_abs_probability_change', 0.0):.4f}") print(f" Path vs Random path flip diff: {summary_stats.get('path_flip_rate_minus_random', 0.0):.4f}") print("\nโœ“ Faithfulness experiments summarized\n") # Generate a visualization for the prompt. 
print("๐Ÿ“ˆ Generating visualization...") if 'graph' in analysis and analysis['pruned_graph_stats']['n_nodes'] > 0: viz_path = os.path.join(RESULTS_DIR, f"attribution_graph_prompt_{prompt_idx + 1}.png") pipeline.attribution_graph.visualize_graph(analysis['graph'], save_path=viz_path) print(f" โœ“ Graph visualization saved to {viz_path}") else: print(" - Skipping visualization as no graph was generated or it was empty.") # Save the results in a format for the web app. save_path = os.path.join(RESULTS_DIR, f"attribution_graphs_results_prompt_{prompt_idx + 1}.json") # Create a JSON file that can be merged with others. final_results = { "analyses": { f"prompt_{prompt_idx + 1}": analysis }, "config": config.__dict__, "timestamp": str(time.time()) } # The web page doesn't use the graph object, so remove it. if 'graph' in final_results['analyses'][f"prompt_{prompt_idx + 1}"]: del final_results['analyses'][f"prompt_{prompt_idx + 1}"]['graph'] pipeline.save_results(final_results, save_path) print(f"๐Ÿ’พ Results saved to {save_path}") print("\n๐ŸŽ‰ Analysis for this prompt completed successfully!") except Exception as e: print(f"โŒ Error during execution: {e}") import traceback traceback.print_exc() if __name__ == "__main__": main()