# Compute CPR / CMD circuit-faithfulness metrics for OLMo attribution graphs.
import argparse
import json
import logging
import math
import os
import sys
from pathlib import Path
from typing import List, Tuple

import networkx as nx
import numpy as np

# Make the repository root importable so the circuit_analysis package resolves
# regardless of the current working directory.
sys.path.append(str(Path(__file__).resolve().parent.parent))

from circuit_analysis.attribution_graphs_olmo import (
    AttributionGraphsPipeline,
    AttributionGraphConfig,
    ANALYSIS_PROMPTS,
    AttributionGraph
)

# Module-wide logging setup.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def compute_cpr(k_values: List[float], f_values: List[float]) -> float:
    """
    Compute CPR (Integrated Circuit Performance Ratio) via the trapezoidal rule.

    CPR = Integral of f(C_k) dk over the sampled grid.

    Args:
        k_values: Increasing edge-fraction sample points.
        f_values: Faithfulness values f(C_k), parallel to ``k_values``.

    Returns:
        Trapezoidal approximation of the integral (0.0 for < 2 points).
    """
    area = 0.0
    for i in range(1, len(k_values)):
        width = k_values[i] - k_values[i - 1]
        area += 0.5 * (f_values[i - 1] + f_values[i]) * width
    return area
def compute_cmd(k_values: List[float], f_values: List[float]) -> float:
    """
    Compute CMD (Integrated Circuit-Model Distance) via the trapezoidal rule.

    CMD = Integral of |1 - f(C_k)| dk over the sampled grid.

    Args:
        k_values: Increasing edge-fraction sample points.
        f_values: Faithfulness values f(C_k), parallel to ``k_values``.

    Returns:
        Trapezoidal approximation of the integral (0.0 for < 2 points).
    """
    # Distance of each faithfulness sample from perfect faithfulness (1.0).
    deviations = [abs(1.0 - f) for f in f_values]
    area = 0.0
    for i in range(1, len(k_values)):
        width = k_values[i] - k_values[i - 1]
        area += 0.5 * (deviations[i - 1] + deviations[i]) * width
    return area
| def get_active_features_from_graph(graph: nx.DiGraph) -> List[Tuple[int, int]]: | |
| """ | |
| Extracts the list of feature nodes (as layer_idx, feature_idx tuples) from the graph. | |
| """ | |
| features = [] | |
| for node in graph.nodes(): | |
| if node.startswith("feat_"): | |
| parts = node.split('_') | |
| try: | |
| # Format: feat_L{layer}_T{token}_F{feature} | |
| layer_idx = int(parts[1][1:]) | |
| feature_idx = int(parts[3][1:]) | |
| # We only care about unique (layer, feature) pairs for ablation | |
| features.append((layer_idx, feature_idx)) | |
| except (IndexError, ValueError): | |
| continue | |
| return list(set(features)) | |
| def calculate_graph_importance(attribution_graph_obj: AttributionGraph, graph: nx.DiGraph) -> List[Tuple[str, float]]: | |
| """ | |
| Calculates the importance of each feature node in the graph based on edge weights. | |
| Returns a list of (node_id, importance_score) sorted by importance descending. | |
| """ | |
| node_importance = {} | |
| # Identify feature nodes | |
| feature_nodes = [n for n in graph.nodes() if attribution_graph_obj.node_types.get(n) == "feature"] | |
| # Calculate importance as sum of absolute weights of connected edges | |
| for node in feature_nodes: | |
| importance = 0.0 | |
| # Outgoing edges | |
| for _, target in graph.out_edges(node): | |
| weight = attribution_graph_obj.edge_weights.get((node, target), 0.0) | |
| importance += abs(weight) | |
| # Incoming edges? MIB usually focuses on "importance" for the task. | |
| # Using sum of absolute edge weights is a standard proxy. | |
| # attribution_graphs_olmo.py prune_graph uses sum of all connected edge weights (in and out). | |
| for source, _ in graph.in_edges(node): | |
| weight = attribution_graph_obj.edge_weights.get((source, node), 0.0) | |
| importance += abs(weight) | |
| node_importance[node] = importance | |
| return sorted(node_importance.items(), key=lambda x: x[1], reverse=True) | |
| def get_edges_count(graph: nx.DiGraph, nodes: List[str]) -> int: | |
| """ | |
| Returns the number of edges in the subgraph induced by the given nodes | |
| (plus edges to output/embedding if we consider them part of the circuit context). | |
| However, strictly following "fraction of total edges": | |
| We should count edges where BOTH source and target are in the kept set (including embeddings/output). | |
| """ | |
| # Assuming embeddings and output are always "kept" or don't count towards the quota | |
| # if we only ablate features. | |
| # But for the metric k = |C|/|N|, we need a consistent definition. | |
| # Let's define |C| as the number of edges in the subgraph induced by (Selected Features + Embeddings + Output). | |
| nodes_set = set(nodes) | |
| count = 0 | |
| for u, v in graph.edges(): | |
| if u in nodes_set and v in nodes_set: | |
| count += 1 | |
| return count | |
def run_cpr_cmd_analysis(pipeline: AttributionGraphsPipeline, prompt_idx: int):
    """
    Compute CPR and CMD for a given prompt, using:
    - Universe: all feature nodes present in the attribution graph
    - Metric m: logit(target) only (no foil)
    - Interventions: ablation of feature sets with intervention_strength=1.0

    Returns:
        A dict with the faithfulness curve, CPR, CMD and metadata, or None
        when the prompt yields no features / no finite metric values.

    BUGFIX: actual_k_values was previously appended BEFORE the isfinite(m_Ck)
    guard, while f_values was appended after it — so a single skipped k-point
    desynchronized the two lists and the later zip() paired the wrong k with
    the wrong f. Both appends now happen together, after the guard.
    """
    prompt = ANALYSIS_PROMPTS[prompt_idx]
    logger.info(f"Analyzing prompt {prompt_idx}: '{prompt}'")
    # Build/prune the attribution graph for this prompt
    pipeline.analyze_prompt(prompt)
    full_graph = pipeline.attribution_graph.graph
    # Baseline: run once to get logits & feature activations.
    # NOTE(review): _prepare_inputs presumably returns tensors keyed as below —
    # confirm against PerturbationExperiments.
    baseline_data = pipeline.perturbation_experiments._prepare_inputs(prompt, top_k=1)
    # m(N): logit of the model's own top predicted token on the clean run.
    target_token_id = baseline_data['baseline_top_tokens'].indices[0].item()
    baseline_logits = baseline_data['baseline_last_token_logits']
    m_N = baseline_logits[target_token_id].item()
    logger.info(
        f"Baseline m(N) = {m_N:.4f} "
        f"(Token: {pipeline.tokenizer.decode([target_token_id])})"
    )
    # Universe: all feature nodes in the graph
    universe_features = get_active_features_from_graph(full_graph)
    logger.info(f"Graph Universe size: {len(universe_features)} features")
    if not universe_features:
        logger.warning("No features found in graph. Skipping.")
        return None
    # Empty circuit: ablate all universe features
    empty_res = pipeline.perturbation_experiments.feature_set_ablation_experiment(
        prompt,
        feature_set=universe_features,
        intervention_strength=1.0,
        target_token_id=target_token_id
    )
    m_empty = empty_res["ablated_logit"]
    logger.info(f"Empty m(Ø) = {m_empty:.4f}")
    # Full ablation can produce inf/nan logits, which would poison the
    # normalization below — bail out early.
    if not math.isfinite(m_empty):
        logger.warning(
            f"m_empty is non-finite ({m_empty}) for prompt {prompt_idx}; "
            "skipping CPR/CMD for this prompt."
        )
        return None
    # Node importance within the graph
    sorted_nodes = calculate_graph_importance(pipeline.attribution_graph, full_graph)
    total_edges = full_graph.number_of_edges()
    # Target edge-fraction grid (roughly logarithmic); achieved k may differ.
    k_grid = [0.001, 0.002, 0.005, 0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1.0]
    f_values = []
    actual_k_values = []
    # Embeddings/output are always kept
    always_kept_nodes = [n for n in full_graph.nodes() if not n.startswith("feat_")]
    logger.info("Computing faithfulness curve...")
    for k in k_grid:
        target_edge_count = int(k * total_edges)
        current_circuit_nodes = list(always_kept_nodes)
        current_feature_tuples = []
        # Greedily add the most important features until the induced subgraph
        # reaches the edge budget; always include at least one feature.
        for node, _ in sorted_nodes:
            current_edge_count = get_edges_count(full_graph, current_circuit_nodes)
            if current_edge_count >= target_edge_count and len(current_feature_tuples) > 0:
                break
            current_circuit_nodes.append(node)
            # Node id format: feat_L{layer}_T{token}_F{feature}
            parts = node.split("_")
            l = int(parts[1][1:])
            f = int(parts[3][1:])
            current_feature_tuples.append((l, f))
        actual_edges = get_edges_count(full_graph, current_circuit_nodes)
        actual_k = actual_edges / total_edges if total_edges > 0 else 0.0
        # Complement = universe \ current features; m(C_k) is measured by
        # ablating everything OUTSIDE the circuit.
        current_set = set(current_feature_tuples)
        complement_set = [ft for ft in universe_features if ft not in current_set]
        if not complement_set:
            # Nothing left to ablate: the circuit is the full model.
            m_Ck = m_N
        else:
            res = pipeline.perturbation_experiments.feature_set_ablation_experiment(
                prompt,
                feature_set=complement_set,
                intervention_strength=1.0,
                target_token_id=target_token_id
            )
            m_Ck = res["ablated_logit"]
        if not math.isfinite(m_Ck):
            logger.warning(
                f"Non-finite m_Ck={m_Ck} for k={k:.4f} on prompt {prompt_idx}; "
                "skipping this k point."
            )
            continue
        # Normalized faithfulness f(C_k) = (m(C_k) - m(Ø)) / (m(N) - m(Ø)),
        # clamped to [0, 1]; degenerate baseline span maps to 0.
        if abs(m_N - m_empty) < 1e-6:
            f_k = 0.0
        else:
            raw_f = (m_Ck - m_empty) / (m_N - m_empty)
            f_k = max(0.0, min(1.0, raw_f))
        # Append k and f together so the two curves can never desynchronize
        # when a k-point is skipped above.
        actual_k_values.append(actual_k)
        f_values.append(f_k)
    if not actual_k_values or not f_values:
        logger.warning(f"No valid k-points for prompt {prompt_idx}; skipping.")
        return None
    # Sort by achieved k, then anchor the curve at k=0 (f=0) and extend to
    # k=1 with the last observed f so the integrals span the full [0, 1].
    pairs = sorted(zip(actual_k_values, f_values), key=lambda x: x[0])
    sorted_k = [p[0] for p in pairs]
    sorted_f = [p[1] for p in pairs]
    if sorted_k[0] > 0.0:
        sorted_k.insert(0, 0.0)
        sorted_f.insert(0, 0.0)
    if sorted_k[-1] < 1.0:
        last_f = sorted_f[-1]
        sorted_k.append(1.0)
        sorted_f.append(last_f)
    cpr = compute_cpr(sorted_k, sorted_f)
    cmd = compute_cmd(sorted_k, sorted_f)
    logger.info(f"Result: CPR={cpr:.4f}, CMD={cmd:.4f}")
    return {
        "prompt": prompt,
        "target_token": pipeline.tokenizer.decode([target_token_id]),
        "m_N": m_N,
        "m_empty": m_empty,
        "curve_k": sorted_k,
        "curve_f": sorted_f,
        "CPR": cpr,
        "CMD": cmd
    }
def main():
    """CLI entry point: run CPR/CMD analysis over all ANALYSIS_PROMPTS and save a JSON report."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--output", type=str, default="circuit_analysis/results/cpr_cmd_results.json")
    args = parser.parse_args()

    # Pipeline configuration — thresholds tuned for a rich but tractable graph.
    config = AttributionGraphConfig(
        model_path="models/OLMo-2-1124-7B",  # adjusted below if not found from CWD
        n_features_per_layer=512,  # capped at 512 due to memory constraints
        graph_feature_activation_threshold=0.01,
        graph_edge_weight_threshold=0.003,  # lower threshold -> denser graph (prev: 0.005)
        graph_max_features_per_layer=40,    # prev 24; 100 was too slow
        graph_max_edges_per_node=20,        # prev 12; 50 was too slow
        intervention_strength=1.0,
    )

    # Resolve the model path relative to the repo root when the default
    # relative path is not visible from the current working directory.
    if not os.path.exists(config.model_path):
        root_path = Path(__file__).resolve().parent.parent
        possible_path = root_path / "models" / "OLMo-2-1124-7B"
        if possible_path.exists():
            config.model_path = str(possible_path)

    pipeline = AttributionGraphsPipeline(config)

    # Load the CLT weights; the ablation experiments require them.
    clt_path = "circuit_analysis/models/clt_model.pth"
    if not os.path.exists(clt_path):
        # Fall back to a path relative to this script.
        clt_path = str(Path(__file__).resolve().parent / "models" / "clt_model.pth")
    if os.path.exists(clt_path):
        pipeline.load_clt(clt_path)
    else:
        logger.error(f"CLT model not found at {clt_path}. Please train it first.")
        return

    # Run every prompt independently; one failure must not abort the sweep.
    results = []
    for i in range(len(ANALYSIS_PROMPTS)):
        try:
            res = run_cpr_cmd_analysis(pipeline, i)
            if res:
                results.append(res)
        except Exception as e:
            logger.error(f"Failed prompt {i}: {e}", exc_info=True)

    # Average CPR/CMD; cast to float so the JSON payload holds plain scalars
    # rather than numpy types.
    if results:
        avg_cpr = float(np.mean([r['CPR'] for r in results]))
        avg_cmd = float(np.mean([r['CMD'] for r in results]))
    else:
        avg_cpr = 0.0
        avg_cmd = 0.0

    final_output = {
        "results": results,
        "average_CPR": avg_cpr,
        "average_CMD": avg_cmd
    }

    # dirname is "" when --output is a bare filename; os.makedirs("") raises,
    # so only create the directory when there is one.
    out_dir = os.path.dirname(args.output)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(args.output, 'w') as f:
        json.dump(final_output, f, indent=2)

    print(f"\n\nFinal Average CPR: {avg_cpr:.4f}")
    print(f"Final Average CMD: {avg_cmd:.4f}")
    print(f"Results saved to {args.output}")


if __name__ == "__main__":
    main()