Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Visualize the aggregate metrics produced by offline_circuit_metrics.py | |
| both overall and per prompt. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| from pathlib import Path | |
| from typing import Dict, Any | |
| from textwrap import fill | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import seaborn as sns | |
# Default input files live in the "results" directory next to this script.
DEFAULT_RESULTS = Path(__file__).parent.joinpath("results", "offline_circuit_metrics.json")
DEFAULT_CPR_CMD = Path(__file__).parent.joinpath("results", "cpr_cmd_results.json")

# The figure is written straight into the paper's figures directory, which sits
# under "writing" one level above this script's directory.
DEFAULT_FIG = Path(__file__).parent.parent.joinpath(
    "writing",
    "ELIA__EACL_2026_System_Demonstrations_",
    "figures",
    "offline_circuit_metrics_combined.png",
)
def _load_payload(path: Path) -> Dict[str, Any]:
    """Read the offline circuit-metrics JSON file and validate its top level.

    Args:
        path: Location of the JSON file to read.

    Returns:
        The decoded JSON object, a dict that contains at least the
        ``aggregate_summary`` and ``per_prompt`` keys.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If the content is not valid JSON, if the top-level value is
            not a JSON object, or if either required key is missing.
    """
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Require a mapping before testing for keys: a bare ``in`` test would also
    # succeed on a list or string that merely contains the key names, and would
    # raise TypeError (not ValueError) on null/number payloads.
    if (
        not isinstance(data, dict)
        or "aggregate_summary" not in data
        or "per_prompt" not in data
    ):
        raise ValueError(f"Expected 'aggregate_summary' and 'per_prompt' in {path}")
    return data
def _configure_plot_style() -> None:
    """Apply the seaborn theme and the matplotlib rcParams used for the figure."""
    sns.set_theme(style="ticks", palette="colorblind")

    # Assigned after set_theme(), so these are the final values for these keys.
    overrides = {
        "font.family": "sans-serif",
        "font.sans-serif": "Arial",
        "axes.labelweight": "normal",
        "axes.titleweight": "bold",
        "figure.titleweight": "bold",
        "savefig.dpi": 300,
        "figure.facecolor": "white",
        "axes.facecolor": "white",
        "grid.alpha": 0.2,
        "axes.spines.top": False,
        "axes.spines.right": False,
    }
    for option, setting in overrides.items():
        plt.rcParams[option] = setting
def _load_cpr_cmd(path: Path) -> Dict[str, Any] | None:
    """Load the optional CPR/CMD results file.

    This loader is best-effort: any problem reading the file results in
    ``None`` rather than an exception, so the caller can proceed without CPR
    data.

    Args:
        path: Location of the CPR/CMD results JSON file.

    Returns:
        The decoded JSON object as a dict, or ``None`` when the file does not
        exist, cannot be read or decoded, or does not hold a JSON object.
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        # The file is optional, so a missing file is not worth a warning.
        return None
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"Warning: Could not load CPR/CMD results from {path}: {e}")
        return None

    # Callers use dict methods on the result; reject any other JSON value here
    # instead of letting it fail later.
    if not isinstance(data, dict):
        print(
            f"Warning: Could not load CPR/CMD results from {path}: "
            f"expected a JSON object, got {type(data).__name__}"
        )
        return None
    return data
def plot_combined(
    summary: Dict[str, Any],
    per_prompt: Dict[str, Any],
    output_path: Path,
    cpr_cmd_data: Dict[str, Any] | None = None,
) -> None:
    """Render the grouped bar chart of probability-change metrics and save it.

    One group of four bars is drawn for the aggregate summary, followed by one
    group per entry of ``per_prompt``. When CPR results are supplied, CPR is
    overlaid as a dashed line on a secondary y-axis.

    Args:
        summary: Aggregate statistics. Must contain ``targeted``,
            ``random_baseline``, ``path`` and ``random_path_baseline``, each a
            dict holding ``avg_abs_probability_change``.
        per_prompt: Mapping of prompt key to a dict that may hold ``prompt``
            (display text; the key is used when absent) and
            ``summary_statistics`` (same layout as ``summary``; missing values
            are plotted as 0.0).
        output_path: Destination image file. Missing parent directories are
            created.
        cpr_cmd_data: Optional dict with ``results`` (a list of dicts carrying
            ``prompt`` and ``CPR``) and ``average_CPR``. The CPR overlay is
            drawn only when ``results`` is non-empty.

    Raises:
        KeyError: If ``summary`` lacks one of the required entries.
    """
    _configure_plot_style()

    metric = "avg_abs_probability_change"
    # (statistics key, legend label, index into the colorblind palette, alpha)
    bar_specs = (
        ("targeted", "Targeted Features", 0, None),
        ("random_baseline", "Random Features", 7, 0.7),
        ("path", "Traced Circuits", 2, None),
        ("random_path_baseline", "Random Path Baseline", 3, 0.7),
    )

    # A single flag decides whether CPR is plotted, so the CPR series always
    # has exactly one value per x position (or is not used at all).
    cpr_results = (cpr_cmd_data or {}).get("results") or []
    has_cpr = bool(cpr_results)
    prompt_to_cpr = {
        entry.get("prompt", ""): entry.get("CPR", 0.0) for entry in cpr_results
    }

    # The first position is the aggregate; per-prompt values follow in order.
    labels = [r"$\mathbf{Aggregate}$"]
    series = {stat_key: [summary[stat_key][metric]] for stat_key, *_ in bar_specs}
    cpr_vals = [cpr_cmd_data.get("average_CPR", 0.0)] if has_cpr else []

    for key, data in per_prompt.items():
        prompt_text = data.get("prompt", key)
        labels.append(prompt_text)
        stats = data.get("summary_statistics", {})
        for stat_key, values in series.items():
            values.append(stats.get(stat_key, {}).get(metric, 0.0))
        if has_cpr:
            # Prompts without a CPR entry are plotted as zero.
            cpr_vals.append(prompt_to_cpr.get(prompt_text, 0.0))

    x = np.arange(len(labels))
    width = 0.2

    fig, ax = plt.subplots(figsize=(10, 6), constrained_layout=True)
    try:
        # Secondary y-axis for CPR, created only when there is CPR data.
        ax2 = ax.twinx() if has_cpr else None
        palette = sns.color_palette("colorblind")

        def autolabel(rects):
            """Write each bar's height above it, skipping near-zero bars."""
            for rect in rects:
                height = rect.get_height()
                if height > 0.01:
                    ax.annotate(
                        f"{height:.2f}",
                        xy=(rect.get_x() + rect.get_width() / 2, height),
                        xytext=(0, 3),
                        textcoords="offset points",
                        ha="center",
                        va="bottom",
                        fontsize=14,
                        fontweight="normal",
                        color="black",
                    )

        # Centre the four bars of each group on its tick:
        # offsets of -1.5, -0.5, +0.5 and +1.5 bar widths.
        for slot, (stat_key, legend_label, color_idx, alpha) in enumerate(bar_specs):
            offset = (slot - (len(bar_specs) - 1) / 2) * width
            rects = ax.bar(
                x + offset,
                series[stat_key],
                width,
                label=legend_label,
                color=palette[color_idx],
                alpha=alpha,
            )
            autolabel(rects)

        if ax2 is not None:
            ax2.plot(
                x,
                cpr_vals,
                marker="o",
                linestyle="--",
                linewidth=2,
                markersize=8,
                color="purple",
                label="CPR",
                zorder=5,
            )
            ax2.set_ylabel("CPR", fontsize=16, fontweight="normal", color="black")
            ax2.tick_params(axis="y", labelcolor="black", labelsize=14)
            # CPR is expected to lie in [0, 1]; leave a little headroom.
            ax2.set_ylim(0, 1.1)

            # Value labels, offset to the lower left of each marker.
            for i, cpr_val in enumerate(cpr_vals):
                if cpr_val > 0.01:
                    ax2.annotate(
                        f"{cpr_val:.2f}",
                        xy=(i, cpr_val),
                        xytext=(-20, -5),
                        textcoords="offset points",
                        fontsize=11,
                        color="purple",
                        fontweight="bold",
                        ha="center",
                    )

            # Merge the handles of both axes into a single legend.
            handles1, labels1 = ax.get_legend_handles_labels()
            handles2, labels2 = ax2.get_legend_handles_labels()
            ax.legend(
                handles1 + handles2,
                labels1 + labels2,
                loc="upper left",
                ncol=3,
                frameon=True,
                framealpha=0.9,
                edgecolor="white",
                fontsize=12,
            )
        else:
            ax.legend(
                loc="upper left",
                ncol=2,
                frameon=True,
                framealpha=0.9,
                edgecolor="white",
                fontsize=14,
            )

        ax.set_ylabel("Avg. |Probability Change| (|Δp|)", fontsize=16, fontweight="normal")
        ax.set_xticks(x)

        # Wrap prompt labels; the aggregate label (always first) is kept intact
        # so its LaTeX markup is not broken across lines.
        wrapped_labels = [labels[0]] + [fill(label, 20) for label in labels[1:]]
        ax.set_xticklabels(wrapped_labels, rotation=0, ha="center", fontsize=14)

        ax.grid(axis="y", linestyle="--", alpha=0.3)

        # Headroom for the value labels; fall back to a unit range when every
        # bar is zero so the axis limits are not degenerate.
        y_max = max(max(values) for values in series.values())
        ax.set_ylim(0, y_max * 1.30 if y_max > 0 else 1.0)

        output_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(output_path, dpi=300)
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)
def main() -> None:
    """Parse command-line options, load the metrics, and write the figure.

    Options:
        --input: Path to the metrics JSON file (defaults to DEFAULT_RESULTS).
        --output: Path of the PNG to write (defaults to DEFAULT_FIG).
        --cpr-cmd: Path to the optional CPR/CMD results JSON file
            (defaults to DEFAULT_CPR_CMD).

    Raises:
        SystemExit: If the input file does not exist. The message is written to
            stderr and the process exit status is 1, so calling scripts can
            detect the failure.
    """
    parser = argparse.ArgumentParser(description="Plot offline attribution metrics.")
    parser.add_argument(
        "--input",
        type=str,
        default=str(DEFAULT_RESULTS),
        help="Path to offline_circuit_metrics.json",
    )
    parser.add_argument(
        "--output",
        type=str,
        default=str(DEFAULT_FIG),
        help="Path to save the per-prompt figure (PNG).",
    )
    parser.add_argument(
        "--cpr-cmd",
        type=str,
        default=str(DEFAULT_CPR_CMD),
        help="Path to CPR/CMD results JSON file (optional).",
    )
    args = parser.parse_args()

    input_path = Path(args.input)
    if not input_path.exists():
        # Exit with a non-zero status instead of returning normally; a plain
        # return would report success to the calling shell.
        raise SystemExit(
            f"Error: Input file {args.input} not found. Please run offline_circuit_metrics.py first."
        )

    payload = _load_payload(input_path)
    summary = payload["aggregate_summary"]
    per_prompt = payload["per_prompt"]

    # CPR/CMD data is optional; the loader returns None when it is unavailable.
    cpr_cmd_data = _load_cpr_cmd(Path(args.cpr_cmd))

    plot_combined(summary, per_prompt, Path(args.output), cpr_cmd_data)
    print(f"Saved combined plot to {args.output}")


if __name__ == "__main__":
    main()