import streamlit as st
import inseq
import torch
import os
from transformers import AutoTokenizer, AutoModelForCausalLM
import json
import requests
from bs4 import BeautifulSoup
import pandas as pd
import numpy as np
from inseq.models.huggingface_model import HuggingfaceDecoderOnlyModel
import base64
from io import BytesIO
from PIL import Image
import plotly.graph_objects as go
import re
import markdown
from utilities.localization import tr
import faiss
from sentence_transformers import SentenceTransformer, util
from sentence_splitter import SentenceSplitter
import html
from utilities.utils import init_qwen_api
from utilities.feedback_survey import display_attribution_feedback
from thefuzz import process, fuzz
import gc
import time
import sys
from pathlib import Path

# A dictionary to map method names to translation keys.
METHOD_DESC_KEYS = {
    "integrated_gradients": "desc_integrated_gradients",
    "occlusion": "desc_occlusion",
    "saliency": "desc_saliency"
}

# Configuration for the influence tracer.
sys.path.append(str(Path(__file__).resolve().parent.parent))
INDEX_DIR = os.path.join("influence_tracer", "influence_tracer_data")
INDEX_PATH = os.path.join(INDEX_DIR, "dolma_index_multi.faiss")
MAPPING_PATH = os.path.join(INDEX_DIR, "dolma_mapping_multi.json")
TRACER_MODEL_NAME = 'paraphrase-multilingual-mpnet-base-v2'


class CachedAttribution:
    # A mock object to mimic inseq's Attribution object for cached results.
    def __init__(self, html_content):
        self.html_content = html_content

    def show(self, display=False, return_html=True):
        return self.html_content


def load_all_attribution_models():
    # Loads all the attribution models.
    try:
        # Set the device to MPS, CUDA, or CPU.
        device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"

        # Path to the local model.
        model_path = "./models/OLMo-2-1124-7B"
        hf_token = os.environ.get("HF_TOKEN")

        # Load tokenizer and model.
        tokenizer = AutoTokenizer.from_pretrained(model_path, token=hf_token, trust_remote_code=True)
        tokenizer.model_max_length = 512

        # Load the model with half precision to save memory.
        base_model = AutoModelForCausalLM.from_pretrained(
            model_path,
            token=hf_token,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            trust_remote_code=True
        )

        # Move the model to the selected device.
        base_model = base_model.to(device)

        # Add missing special tokens if necessary.
        if tokenizer.bos_token is None:
            tokenizer.add_special_tokens({'bos_token': ''})
            base_model.resize_token_embeddings(len(tokenizer))

        # Patch the model config.
        if base_model.config.bos_token_id is None:
            base_model.config.bos_token_id = tokenizer.bos_token_id

        attribution_models = {}

        # Set up the Integrated Gradients model.
        attribution_models["integrated_gradients"] = HuggingfaceDecoderOnlyModel(
            model=base_model,
            tokenizer=tokenizer,
            device=device,
            attribution_method="integrated_gradients",
            attribution_kwargs={"n_steps": 10}
        )

        # Set up the Occlusion model.
        attribution_models["occlusion"] = HuggingfaceDecoderOnlyModel(
            model=base_model,
            tokenizer=tokenizer,
            device=device,
            attribution_method="occlusion"
        )

        # Set up the Saliency model.
        attribution_models["saliency"] = HuggingfaceDecoderOnlyModel(
            model=base_model,
            tokenizer=tokenizer,
            device=device,
            attribution_method="saliency"
        )

        return attribution_models, tokenizer, base_model, device
    except Exception as e:
        st.error(f"Error loading models: {str(e)}")
        return None, None, None, None


def load_influence_tracer_data():
    # Loads the data needed for the influence tracer.
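    # The tracer encodes queries with normalize_embeddings=True, so the scores
    # returned by faiss_index.search() act like cosine similarities, assuming
    # the offline build step that produced dolma_index_multi.faiss used an
    # inner-product index (e.g. IndexFlatIP) over normalized vectors.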
    if not os.path.exists(INDEX_PATH) or not os.path.exists(MAPPING_PATH):
        return None, None, None

    index = faiss.read_index(INDEX_PATH)
    with open(MAPPING_PATH, 'r', encoding='utf-8') as f:
        mapping = json.load(f)

    device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
    model = SentenceTransformer(TRACER_MODEL_NAME, device=device)
    return index, mapping, model


@st.cache_data(persist=True)
def get_influential_docs(text_to_trace: str, lang: str):
    # Finds influential documents from the training data for a given text.
    faiss_index, doc_mapping, tracer_model = load_influence_tracer_data()
    if not faiss_index:
        return []

    # Get the embedding for the input text.
    doc_embedding = tracer_model.encode([text_to_trace], convert_to_numpy=True, normalize_embeddings=True)

    # Search the FAISS index for the top k documents.
    k = 3
    similarities, indices = faiss_index.search(doc_embedding.astype('float32'), k)

    # Find the most similar sentence in each influential document.
    results = []
    query_embedding = tracer_model.encode([text_to_trace], normalize_embeddings=True)
    for i in range(k):
        doc_id = str(indices[0][i])
        if doc_id in doc_mapping:
            doc_info = doc_mapping[doc_id]
            file_path = os.path.join("influence_tracer", "dolma_dataset_sample_1.6v", doc_info['file'])
            try:
                full_doc_text = ""
                with open(file_path, 'r', encoding='utf-8') as f:
                    for line in f:
                        try:
                            line_data = json.loads(line)
                            line_text = line_data.get('text', '')
                            # Use fuzzy matching to find the text snippet.
                            if fuzz.partial_ratio(doc_info['text_snippet'], line_text) > 95:
                                full_doc_text = line_text
                                break
                        except json.JSONDecodeError:
                            continue

                # Skip if the document text wasn't found.
                if not full_doc_text:
                    print(f"Warning: Could not find document snippet for doc {doc_id} in {file_path}. Skipping.")
                    continue

                # Find the most similar sentence in the document.
                splitter = SentenceSplitter(language=lang)
                sentences = splitter.split(text=full_doc_text)
                if not sentences:
                    sentences = [full_doc_text]

                # Set a batch size to avoid memory issues.
                sentence_embeddings = tracer_model.encode(sentences, batch_size=64, show_progress_bar=False, normalize_embeddings=True)
                cos_scores = util.pytorch_cos_sim(query_embedding, sentence_embeddings)[0]
                best_sentence_idx = torch.argmax(cos_scores).item()
                most_similar_sentence = sentences[best_sentence_idx]

                results.append({
                    'id': doc_id,
                    'file': doc_info['file'],
                    'source': doc_info['source'],
                    'text': full_doc_text,
                    'similarity': float(similarities[0][i]),
                    'highlight_sentence': str(most_similar_sentence)
                })
            except (IOError, KeyError) as e:
                print(f"Could not retrieve full text for doc {doc_id}: {e}")
                continue

    return results


# --- Qwen API for Explanations ---
@st.cache_data(persist=True)
def _cached_explain_heatmap(api_config, img_base64, csv_text, structured_prompt):
    # Makes a cached API call to Qwen to get an explanation for a heatmap.
    headers = {
        "Authorization": f"Bearer {api_config['api_key']}",
        "Content-Type": "application/json"
    }
    content = [{"type": "text", "text": structured_prompt}]
    if img_base64:
        content.append({
            "type": "image_url",
            "image_url": {
                "url": f"data:image/png;base64,{img_base64}"
            }
        })
    data = {
        "model": api_config["model"],
        "messages": [
            {
                "role": "user",
                "content": content
            }
        ],
        "max_tokens": 1200,
        "temperature": 0.2,
        "top_p": 0.95,
        "seed": 42
    }
    response = requests.post(
        f"{api_config['api_endpoint']}/chat/completions",
        headers=headers,
        json=data,
        timeout=300
    )
    # Raise an exception if the API call fails.
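    # raise_for_status() turns HTTP 4xx/5xx responses into exceptions; the
    # caller (explain_heatmap_with_csv_data) catches them and shows a Streamlit
    # error message instead of crashing the page.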
response.raise_for_status() result = response.json() return result["choices"][0]["message"]["content"] @st.cache_data(persist=True) def generate_all_attribution_analyses(_attribution_models, _tokenizer, _base_model, _device, prompt, max_tokens, force_exact_num_tokens=False): # Generates text and runs attribution analysis for all methods. # Generate the text first. inputs = _tokenizer(prompt, return_tensors="pt").to(_device) generation_args = { 'max_new_tokens': max_tokens, 'do_sample': False } if force_exact_num_tokens: generation_args['min_new_tokens'] = max_tokens generated_ids = _base_model.generate( inputs.input_ids, **generation_args ) generated_text = _tokenizer.decode(generated_ids[0], skip_special_tokens=True) # Run attribution analysis for all methods. all_attributions = {} methods = ["integrated_gradients", "occlusion", "saliency"] for method in methods: attributions = _attribution_models[method].attribute( input_texts=prompt, generated_texts=generated_text ) all_attributions[method] = attributions return generated_text, all_attributions def explain_heatmap_with_csv_data(api_config, image_buffer, csv_data, context_prompt, generated_text, method_name="Attribution"): # Generates an explanation for a heatmap using the Qwen API. try: # Convert the image to base64. img_base64 = None if image_buffer: image_buffer.seek(0) image = Image.open(image_buffer) buffered = BytesIO() image.save(buffered, format="PNG") img_base64 = base64.b64encode(buffered.getvalue()).decode() # Clean the dataframe to handle duplicates. df_clean = csv_data.copy() cols = pd.Series(df_clean.columns) if cols.duplicated().any(): for dup in cols[cols.duplicated()].unique(): dup_indices = cols[cols == dup].index.values new_names = [f"{dup} ({i+1})" for i in range(len(dup_indices))] cols[dup_indices] = new_names df_clean.columns = cols if df_clean.index.has_duplicates: counts = {} new_index = list(df_clean.index) duplicated_indices = df_clean.index[df_clean.index.duplicated(keep=False)] for i, idx in enumerate(df_clean.index): if idx in duplicated_indices: counts[idx] = counts.get(idx, 0) + 1 new_index[i] = f"{idx} ({counts[idx]})" df_clean.index = new_index # --- Rule-Based Analysis --- unstacked = df_clean.unstack() unstacked.index = unstacked.index.map('{0[1]} -> {0[0]}'.format) # Get the top 5 individual scores. top_5_individual = unstacked.abs().nlargest(5).sort_index() top_individual_text_lines = ["\n### Top 5 Strongest Individual Connections:"] for label in top_5_individual.index: score = unstacked[label] top_individual_text_lines.append(f"- **{label}**: score {score:.2f}") # Get the top 5 average input scores. avg_input_scores = df_clean.mean(axis=1) top_5_average = avg_input_scores.abs().nlargest(5).sort_index() top_average_text_lines = ["\n### Top 5 Most Influential Input Tokens (on average over the whole generation):"] for input_token in top_5_average.index: score = avg_input_scores[input_token] top_average_text_lines.append(f"- **'{input_token}'**: average score {score:.2f}") # Get the top output token sources. top_output_text_lines = [] if not df_clean.empty: avg_output_scores = df_clean.mean(axis=0) top_3_output = avg_output_scores.abs().nlargest(min(3, len(df_clean.columns))).sort_index() if not top_3_output.empty: top_output_text_lines.append("\n### Top 3 Most Influenced Generated Tokens:") for output_token in top_3_output.index: # Find which input tokens influenced this output token the most. 
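# (abs() ranks connections by magnitude, so strongly negative scores count as
# influential too; sort_index() then lists the selected tokens in label order
# rather than score order.)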
top_sources_for_output = df_clean[output_token].abs().nlargest(min(2, len(df_clean.index))).sort_index().index.tolist() if top_sources_for_output: top_output_text_lines.append(f"- **'{output_token}'** was most influenced by **'{', '.join(top_sources_for_output)}'**.") data_text_for_llm = "\n".join(top_individual_text_lines + top_average_text_lines + top_output_text_lines) # Get method-specific context from the translation files. desc_key = METHOD_DESC_KEYS.get(method_name, "unsupported_method_desc") method_context = tr(desc_key) # Format the instruction for the LLM. instruction_p1 = tr('instruction_part_1_desc').format(method_name=method_name.replace('_', ' ').title()) # Create the prompt for the LLM. structured_prompt = f"""{tr('ai_expert_intro')} ## {tr('analysis_details')} - **{tr('method_being_used')}** {method_name.replace('_', ' ').title()} - **{tr('prompt_analyzed')}** "{context_prompt}" - **{tr('full_generated_text')}** "{generated_text}" ## {tr('method_specific_context')} {method_context} ## {tr('instructions_for_analysis')} {tr('instruction_part_1_header')} {instruction_p1} {tr('instruction_synthesis_header')} {tr('instruction_synthesis_desc')} {tr('instruction_color_coding')} ## {tr('data_section_header')} {data_text_for_llm} {tr('begin_analysis_now')}""" # Call the cached function to get the explanation. explanation = _cached_explain_heatmap(api_config, img_base64, data_text_for_llm, structured_prompt) return explanation except Exception as e: # Catch errors from data prep or the API call. st.error(f"Error generating AI explanation: {str(e)}") return tr("unable_to_generate_explanation") # --- Faithfulness Verification --- @st.cache_data(persist=True) def _cached_extract_claims_from_explanation(api_config, explanation_text, analysis_method): # Makes a cached API call to Qwen to get claims from an explanation. headers = {"Authorization": f"Bearer {api_config['api_key']}", "Content-Type": "application/json"} # Dynamically set claim types based on the analysis method. claim_types_details = tr("claim_extraction_prompt_types_details") claim_extraction_prompt = f"""{tr('claim_extraction_prompt_header')} {tr('claim_extraction_prompt_instruction')} {tr('claim_extraction_prompt_context_header').format(analysis_method=analysis_method, context=analysis_method)} {tr('claim_extraction_prompt_types_header')} {claim_types_details} {tr('claim_extraction_prompt_example_header')} {tr('claim_extraction_prompt_example_explanation')} {tr('claim_extraction_prompt_example_json')} {tr('claim_extraction_prompt_analyze_header')} "{explanation_text}" {tr('claim_extraction_prompt_instruction_footer')} """ data = { "model": api_config["model"], "messages": [ { "role": "user", "content": [{"type": "text", "text": claim_extraction_prompt}] } ], "max_tokens": 1500, "temperature": 0.0, # Set to 0 for deterministic output. "seed": 42 } response = requests.post( f"{api_config['api_endpoint']}/chat/completions", headers=headers, json=data, timeout=300 ) response.raise_for_status() claims_text = response.json()["choices"][0]["message"]["content"] try: # The response might be inside a markdown code block, so we try to extract it. if '```json' in claims_text: claims_text = re.search(r'```json\n(.*?)\n```', claims_text, re.DOTALL).group(1) # Parse the JSON string into a Python list. 
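# Each claim is expected to be a dict with "claim_text", "claim_type" and a
# "details" object (tokens, qualifier, score_type, justification, ...), the
# same fields verify_claims() reads further below.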
return json.loads(claims_text) except (AttributeError, json.JSONDecodeError): return [] @st.cache_data(persist=True) def _cached_verify_token_justification(api_config, analysis_method, input_prompt, generated_text, token, justification): # Uses an LLM to verify if a justification for a token's importance is sound. headers = {"Authorization": f"Bearer {api_config['api_key']}", "Content-Type": "application/json"} verification_prompt = f"""{tr('justification_verification_prompt_header')} {tr('justification_verification_prompt_crucial_rule')} {tr('justification_verification_prompt_token_location')} {tr('justification_verification_prompt_special_tokens')} {tr('justification_verification_prompt_evaluating_justifications')} {tr('justification_verification_prompt_linguistic_context')} {tr('justification_verification_prompt_collective_reasoning')} **Analysis Method:** {analysis_method} **Input Prompt:** "{input_prompt}" **Generated Text:** "{generated_text}" **Token in Question:** "{token}" **Provided Justification:** "{justification}" {tr('justification_verification_prompt_task_header')} {tr('justification_verification_prompt_task_instruction')} {tr('justification_verification_prompt_json_instruction')} {tr('justification_verification_prompt_footer')} """ data = { "model": "qwen2.5-vl-72b-instruct", "messages": [{"role": "user", "content": verification_prompt}], "max_tokens": 400, "temperature": 0.0, "seed": 42, "response_format": {"type": "json_object"} } response = requests.post( f"{api_config['api_endpoint']}/chat/completions", headers=headers, json=data, timeout=300 ) response.raise_for_status() try: result_json = response.json()["choices"][0]["message"]["content"] return json.loads(result_json) except (json.JSONDecodeError, KeyError): return {"is_verified": False, "reasoning": "Could not parse the semantic justification result."} def verify_claims(claims, analysis_data): # Verifies the extracted claims against the analysis data. verification_results = [] # Pre-calculate thresholds and rankings for efficiency. all_scores_flat = analysis_data['scores_df'].abs().values.flatten() # Average influence of each input token. avg_input_scores_abs = analysis_data['scores_df'].mean(axis=1).abs().sort_values(ascending=False) avg_input_scores_raw = analysis_data['scores_df'].mean(axis=1) # Keep signs for specific value checks # Average influence on each generated token. avg_output_scores = analysis_data['scores_df'].mean(axis=0).abs().sort_values(ascending=False) input_tokens = analysis_data['scores_df'].index.tolist() generated_tokens = analysis_data['scores_df'].columns.tolist() for claim in claims: is_verified = False evidence = "Could not be verified." details = claim.get('details', {}) claim_type = claim.get('claim_type') try: # Clean tokens in the claim's details, as the LLM sometimes includes extra quotes. if 'token' in details and isinstance(details['token'], str): details['token'] = re.sub(r"^\s*['\"]|['\"]\s*$", '', details['token']).strip() if 'tokens' in details and isinstance(details['tokens'], list): details['tokens'] = [re.sub(r"^\s*['\"]|['\"]\s*$", '', t).strip() for t in details['tokens']] if claim_type == 'attribution_claim': tokens_claimed = details.get('tokens', []) qualifier = details.get('qualifier', 'significant') # Default to the lower bar score_type = details.get('score_type', 'peak') # Calculate the correct scores based on the claim's score_type. 
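# "average" ranks a token by its mean influence across the whole generation;
# "peak" (the default) takes its strongest single connection, whether the token
# gave influence as an input or received it as a generated token.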
if score_type == 'average': score_series = analysis_data['scores_df'].abs().mean(axis=1) score_name = "average score" else: # peak # Check both influence GIVEN (input) and RECEIVED (output) # We use fillna(0) to handle cases where a token is not in that axis input_peaks = analysis_data['scores_df'].abs().max(axis=1) output_peaks = analysis_data['scores_df'].abs().max(axis=0) combined_scores = {} all_tokens = set(input_peaks.index) | set(output_peaks.index) for t in all_tokens: s1 = input_peaks.get(t, 0.0) s2 = output_peaks.get(t, 0.0) combined_scores[t] = max(s1, s2) score_series = pd.Series(combined_scores) score_name = "peak score" if score_series.empty: evidence = "No attribution data available to verify claim." else: all_attributions = sorted( [{'token': token, 'attribution': score} for token, score in score_series.items()], key=lambda x: x['attribution'], reverse=True ) max_score = all_attributions[0]['attribution'] if all_attributions else 0 if qualifier == 'high': threshold = 0.70 * max_score threshold_name = "high" else: # 'significant' or default threshold = 0.50 * max_score threshold_name = "significant" token_scores_dict = {item['token'].lower().strip(): item['attribution'] for item in all_attributions} unverified_tokens = [] verified_tokens_details = [] for token in tokens_claimed: # New, more robust matching logic. # First, check for a direct match for specific claims like ', (1)'. token_lower = token.lower().strip() if token_lower in token_scores_dict: matching_keys = [token_lower] else: # If no direct match, fall back to a generic search for claims like ','. # This finds all instances: ', (1)', ', (2)', etc. matching_keys = [ k for k in token_scores_dict.keys() if re.sub(r'\s\(\d+\)$', '', k).strip() == token_lower ] if not matching_keys: unverified_tokens.append(f"'{token}' (not found in analysis)") continue # Check each matching instance against the threshold. for key in matching_keys: actual_score = token_scores_dict.get(key) if abs(actual_score) < threshold: unverified_tokens.append(f"'{key}' ({score_name}: {abs(actual_score):.2f})") else: verified_tokens_details.append(f"'{key}' ({score_name}: {abs(actual_score):.2f})") is_verified = not unverified_tokens if is_verified: evidence = f"Verified. All claimed tokens passed the {threshold_name} threshold (> {threshold:.2f}). Details: {', '.join(verified_tokens_details)}." else: fail_reason = f"the following did not meet the {threshold_name} threshold (> {threshold:.2f}): {', '.join(unverified_tokens)}" if verified_tokens_details: evidence = f"While some tokens passed ({', '.join(verified_tokens_details)}), {fail_reason}." else: evidence = f"The following did not meet the {threshold_name} threshold (> {threshold:.2f}): {', '.join(unverified_tokens)}." elif claim_type in ['token_justification_claim', 'token_begruendung_anspruch']: token_val = details.get('token') or details.get('tokens') if isinstance(token_val, list): token = ", ".join(map(str, token_val)) else: token = token_val justification = details.get('justification') or details.get('begruendung') input_prompt = analysis_data.get('prompt', '') generated_text = analysis_data.get('generated_text', '') if not all([token, justification, input_prompt, generated_text]): evidence = "Missing data for justification verification (token, justification, or prompt)." 
else: api_config = init_qwen_api() if api_config: verification = _cached_verify_token_justification(api_config, analysis_data['method'], input_prompt, generated_text, token, justification) is_verified = verification.get('is_verified', False) evidence = verification.get('reasoning', "Failed to get semantic reasoning for justification.") else: is_verified = False evidence = "API key not configured for semantic verification." except Exception as e: evidence = f"An error occurred during verification: {str(e)}" verification_results.append({ 'claim_text': claim.get('claim_text', 'N/A'), 'verified': is_verified, 'evidence': evidence }) return verification_results # --- End Faithfulness Verification --- def create_heatmap_visualization(attributions, method_name="Attribution"): # Creates a heatmap visualization from attribution scores. try: # Get the HTML content from the attributions. html_content = attributions.show(display=False, return_html=True) if not html_content: st.error(tr("error_inseq_no_html").format(method_name=method_name)) return None, None, None, None # Parse the HTML to extract the data table. soup = BeautifulSoup(html_content, 'html.parser') table = soup.find('table') if not table: st.error(tr("error_no_table_in_html").format(method_name=method_name)) return None, None, None, None # A more structured approach to parsing the HTML. header_row_element = table.find('thead') if header_row_element: headers = [th.get_text(strip=True) for th in header_row_element.find_all('th')[1:]] else: # Fallback if no is found. first_row = table.find('tr') if not first_row: st.error(tr("error_table_no_rows").format(method_name=method_name)) return None, None, None, None headers = [th.get_text(strip=True) for th in first_row.find_all('th')[1:]] data_rows = [] row_labels = [] # Find all `` elements and iterate through their rows. table_bodies = table.find_all('tbody') if not table_bodies: # Fallback if no is found. all_trs = table.find_all('tr') data_trs = all_trs[1:] if len(all_trs) > 1 else [] else: data_trs = [] for tbody in table_bodies: data_trs.extend(tbody.find_all('tr')) for tr_element in data_trs: all_cells = tr_element.find_all(['th', 'td']) if not all_cells or len(all_cells) <= 1: continue row_labels.append(all_cells[0].get_text(strip=True)) # Convert text values to float, handling empty strings as 0. row_data = [] for cell in all_cells[1:]: text_val = cell.get_text(strip=True) # Remove non-breaking spaces. clean_text = text_val.replace('\xa0', '').strip() if clean_text: try: row_data.append(float(clean_text)) except ValueError: # Default to 0 if conversion fails. row_data.append(0.0) else: row_data.append(0.0) data_rows.append(row_data) # Create the dataframe from the parsed data. if not data_rows or not data_rows[0]: st.error(tr("error_failed_to_parse_rows").format(method_name=method_name)) return None, None, None, None # --- Make token labels unique for duplicates --- def make_labels_unique(labels): counts = {} new_labels = [] # First, count all occurrences to decide which ones need numbering. label_counts = {label: labels.count(label) for label in set(labels)} for label in labels: if label_counts[label] > 1: counts[label] = counts.get(label, 0) + 1 new_labels.append(f"{label} ({counts[label]})") else: new_labels.append(label) return new_labels unique_row_labels = make_labels_unique(row_labels) unique_headers = make_labels_unique(headers) parsed_df = pd.DataFrame(data_rows, index=unique_row_labels, columns=unique_headers) attribution_scores = parsed_df.values # Clean tokens for display. 
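        # Note: repeated tokens were suffixed with "(1)", "(2)", ... above so the
        # DataFrame index/columns stay unique; verify_claims() strips that suffix
        # again when matching claimed tokens back to these labels.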
        clean_headers = parsed_df.columns.tolist()
        clean_row_labels = parsed_df.index.tolist()

        # Use numerical indices for the heatmap to handle duplicate labels.
        x_indices = list(range(len(clean_headers)))
        y_indices = list(range(len(clean_row_labels)))

        # Prepare custom data for hover labels.
        custom_data = np.empty(attribution_scores.shape, dtype=object)
        for i in range(len(clean_row_labels)):
            for j in range(len(clean_headers)):
                custom_data[i, j] = (clean_row_labels[i], clean_headers[j])

        fig = go.Figure(data=go.Heatmap(
            z=attribution_scores,
            x=x_indices,
            y=y_indices,
            customdata=custom_data,
            hovertemplate="Input: %{customdata[0]}<br>Generated: %{customdata[1]}<br>Score: %{z:.4f}",
            colorscale='Plasma',
            hoverongaps=False,
        ))

        fig.update_layout(
            title=tr('heatmap_title').format(method_name=method_name),
            xaxis_title=tr('heatmap_xaxis'),
            yaxis_title=tr('heatmap_yaxis'),
            xaxis=dict(
                tickmode='array',
                tickvals=x_indices,
                ticktext=clean_headers,
                tickangle=45
            ),
            yaxis=dict(
                tickmode='array',
                tickvals=y_indices,
                ticktext=clean_row_labels,
                autorange='reversed'
            ),
            height=max(400, len(clean_row_labels) * 30),
            width=max(600, len(clean_headers) * 50)
        )

        # Save the plot to a buffer.
        buffer = BytesIO()
        try:
            fig.write_image(buffer, format='png', scale=2)
            buffer.seek(0)
        except Exception as e:
            print(f"Warning: Could not generate static image (Kaleido error?): {e}")
            buffer = None

        return fig, html_content, buffer, parsed_df
    except Exception as e:
        st.error(tr("error_creating_heatmap").format(e=str(e)))
        return None, None, None, None


def start_new_analysis(prompt, max_tokens, enable_explanations):
    # Clears old results and starts a new analysis.
    # Clear old results from the session state.
    keys_to_clear = [
        'generated_text',
        'all_attributions'
    ]
    for key in keys_to_clear:
        if key in st.session_state:
            del st.session_state[key]

    # Clear any old cached items.
    for key in list(st.session_state.keys()):
        if key.startswith('influential_docs_'):
            del st.session_state[key]

    # Update the text area with the new prompt.
    st.session_state.attr_prompt = prompt

    # Set parameters for the new analysis.
    st.session_state.run_request = {
        "prompt": prompt,
        "max_tokens": max_tokens,
        "enable_explanations": enable_explanations
    }


def update_cache_with_explanation(prompt, method_name, explanation):
    cache_file = os.path.join("cache", "cached_attribution_results.json")
    if not os.path.exists(cache_file):
        return
    try:
        with open(cache_file, "r", encoding="utf-8") as f:
            cached_data = json.load(f)
        if prompt in cached_data:
            if "explanations" not in cached_data[prompt]:
                cached_data[prompt]["explanations"] = {}
            cached_data[prompt]["explanations"][method_name] = explanation
            with open(cache_file, "w", encoding="utf-8") as f:
                json.dump(cached_data, f, ensure_ascii=False, indent=4)
            print(f"Saved explanation for {method_name} to cache.")
    except Exception as e:
        print(f"Failed to update cache with explanation: {e}")


def update_cache_with_faithfulness(prompt, method_name, verification_results):
    cache_file = os.path.join("cache", "cached_attribution_results.json")
    if not os.path.exists(cache_file):
        return
    try:
        with open(cache_file, "r", encoding="utf-8") as f:
            cached_data = json.load(f)
        if prompt in cached_data:
            if "faithfulness" not in cached_data[prompt]:
                cached_data[prompt]["faithfulness"] = {}
            cached_data[prompt]["faithfulness"][method_name] = verification_results
            with open(cache_file, "w", encoding="utf-8") as f:
                json.dump(cached_data, f, ensure_ascii=False, indent=4)
            print(f"Saved faithfulness for {method_name} to cache.")
    except Exception as e:
        print(f"Failed to update cache with faithfulness: {e}")


def run_analysis(prompt, max_tokens, enable_explanations, force_exact_num_tokens=False):
    # Runs the full analysis pipeline.
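    # Rough flow: return cached results for known prompts; otherwise load the
    # local OLMo model, generate text plus all three attributions, write the
    # HTML renders and influence docs back to the JSON cache, then free memory.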
if not prompt.strip(): st.warning(tr('please_enter_prompt_warning')) return # Check for cached results first cache_file = os.path.join("cache", "cached_attribution_results.json") if os.path.exists(cache_file): with open(cache_file, "r", encoding="utf-8") as f: cached_data = json.load(f) if prompt in cached_data: print("Loading full attribution analysis from cache.") cached_result = cached_data[prompt] # Check if influential_docs are missing and update the cache if possible if "influential_docs" not in cached_result: try: print(f"Updating cache for '{prompt}' with missing influence docs...") lang = st.session_state.get('lang', 'en') # This call should hit the Streamlit cache and be fast missing_docs = get_influential_docs(prompt, lang) if missing_docs: cached_result["influential_docs"] = missing_docs # Save updated cache back to file with open(cache_file, "w", encoding="utf-8") as f: json.dump(cached_data, f, ensure_ascii=False, indent=4) print("Cache updated successfully.") except Exception as e: print(f"Could not update cache with influence docs: {e}") # Populate session state from the comprehensive cache st.session_state.generated_text = cached_result["generated_text"] st.session_state.prompt = prompt st.session_state.enable_explanations = enable_explanations st.session_state.qwen_api_config = init_qwen_api() if enable_explanations else None # Reconstruct attribution objects and store explanations/faithfulness reconstructed_attributions = {} for method, data in cached_result["html_contents"].items(): reconstructed_attributions[method] = CachedAttribution(data) # Use a consistent key for caching in session state cache_key_base = f"{method}_{cached_result['generated_text']}" if "explanation" in data: st.session_state[f"explanation_{cache_key_base}"] = data["explanation"] if "faithfulness_results" in data: st.session_state[f"faithfulness_check_{cache_key_base}"] = data["faithfulness_results"] # Load new structured cache if "explanations" in cached_result and method in cached_result["explanations"]: st.session_state[f"explanation_{cache_key_base}"] = cached_result["explanations"][method] if "faithfulness" in cached_result and method in cached_result["faithfulness"]: st.session_state[f"faithfulness_check_{cache_key_base}"] = cached_result["faithfulness"][method] st.session_state.all_attributions = reconstructed_attributions # Store influential docs if "influential_docs" in cached_result: # Use a key that the UI part can check for st.session_state.cached_influential_docs = cached_result["influential_docs"] st.success(tr('analysis_complete_success')) return # If not in cache, check if models exist before trying to load model_path = "./models/OLMo-2-1124-7B" if not os.path.exists(model_path): st.info("This live demo is running in a static environment. Only the pre-cached example prompts are available. Please select an example to view its analysis.") return # Load the models. with st.spinner(tr('loading_models_spinner')): attribution_models, tokenizer, base_model, device = load_all_attribution_models() if not attribution_models: st.error(tr('failed_to_load_models_error')) return st.session_state.qwen_api_config = init_qwen_api() if enable_explanations else None st.session_state.enable_explanations = enable_explanations st.session_state.prompt = prompt # Generate text and attributions. 
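    # generate_all_attribution_analyses() is itself wrapped in st.cache_data
    # (its underscore-prefixed arguments are excluded from hashing), so
    # repeating the same prompt and token budget reuses earlier results
    # instead of re-running generation and attribution.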
with st.spinner(tr('running_attribution_analysis_spinner')): try: generated_text, all_attributions = generate_all_attribution_analyses( attribution_models, tokenizer, base_model, device, prompt, max_tokens, force_exact_num_tokens=force_exact_num_tokens ) except Exception as e: st.error(f"Error in attribution analysis: {str(e)}") # Let the rest of the function know it failed. generated_text, all_attributions = None, None if not generated_text or not all_attributions: st.error(tr('failed_to_generate_analysis_error')) return # Store the results in the session state. st.session_state.generated_text = generated_text st.session_state.all_attributions = all_attributions # --- New: Save the new result back to the cache --- try: cache_file = os.path.join("cache", "cached_attribution_results.json") os.makedirs("cache", exist_ok=True) # Load existing cache or create new if os.path.exists(cache_file): with open(cache_file, "r", encoding="utf-8") as f: cached_data = json.load(f) else: cached_data = {} # Add new result html_contents = {method: attr.show(display=False, return_html=True) for method, attr in all_attributions.items()} # Also fetch influential docs to cache them lang = st.session_state.get('lang', 'en') docs_to_cache = get_influential_docs(prompt, lang) cached_data[prompt] = { "generated_text": generated_text, "html_contents": html_contents, "influential_docs": docs_to_cache } # Write back to file with open(cache_file, "w", encoding="utf-8") as f: json.dump(cached_data, f, ensure_ascii=False, indent=4) print(f"Saved new analysis for '{prompt}' to cache.") except Exception as e: print(f"Warning: Could not save result to cache file. {e}") # --- End new section --- # Clean up models to free memory. del attribution_models del tokenizer del base_model gc.collect() if device == 'mps': torch.mps.empty_cache() elif device == 'cuda': torch.cuda.empty_cache() st.success(tr('analysis_complete_success')) def show_attribution_analysis(): # Shows the main attribution analysis page. # Add some CSS for icons. st.markdown('', unsafe_allow_html=True) st.markdown(f"

{tr('attr_page_title')}

", unsafe_allow_html=True) st.markdown(f"{tr('attr_page_desc')}", unsafe_allow_html=True) # Check if a new analysis has been requested by the user. if 'run_request' in st.session_state: request = st.session_state.pop('run_request') run_analysis( prompt=request['prompt'], max_tokens=request['max_tokens'], enable_explanations=request['enable_explanations'] ) # Set up the main layout. col1, col2 = st.columns([1, 1]) with col1: st.markdown(f"

{tr('input_header')}

", unsafe_allow_html=True) # Get the current language from the session state. lang = st.session_state.get('lang', 'en') # Example prompts for English and German. example_prompts = { 'en': [ "The capital of France is", "The first person to walk on the moon was", "To be or not to be, that is the", "Once upon a time, in a land far, far away,", "The chemical formula for water is", "A stitch in time saves", "The opposite of hot is", "The main ingredients of a pizza are", "She opened the door and saw" ], 'de': [ "Die Hauptstadt von Frankreich ist", "Die erste Person auf dem Mond war", "Sein oder Nichtsein, das ist hier die", "Es war einmal, in einem weit, weit entfernten Land,", "Die chemische Formel für Wasser ist", "Was du heute kannst besorgen, das verschiebe nicht auf", "Das Gegenteil von heiß ist", "Die Hauptzutaten einer Pizza sind", "Sie öffnete die Tür und sah" ] } st.markdown('** Example Prompts:**', unsafe_allow_html=True) cols = st.columns(3) for i, example in enumerate(example_prompts[lang][:9]): with cols[i % 3]: st.button( example, key=f"example_{i}", use_container_width=True, on_click=start_new_analysis, args=(example, 10, st.session_state.get('enable_explanations', True)) ) # Text input area for the user's prompt. prompt = st.text_area( tr('enter_prompt'), value=st.session_state.get('attr_prompt', ""), height=100, help=tr('enter_prompt_help'), placeholder="Sadly no GPU available. Please select an example above.", disabled=True ) # Slider for the number of tokens to generate. max_tokens = st.slider( tr('max_new_tokens_slider'), min_value=1, max_value=50, value=5, help=tr('max_new_tokens_slider_help'), disabled=True ) # Checkbox to enable or disable AI explanations. enable_explanations = st.checkbox( tr('enable_ai_explanations'), value=True, help=tr('enable_ai_explanations_help') ) # Button to start the analysis. st.button( tr('generate_and_analyze_button'), type="primary", on_click=start_new_analysis, args=(prompt, max_tokens, enable_explanations), disabled=True ) with col2: st.markdown(f"

{tr('output_header')}

", unsafe_allow_html=True) if hasattr(st.session_state, 'generated_text'): st.subheader(tr('generated_text_subheader')) # Extract the generated part of the text. prompt_part = st.session_state.prompt full_text = st.session_state.generated_text generated_part = full_text if full_text.startswith(prompt_part): generated_part = full_text[len(prompt_part):].lstrip() else: # A fallback in case tokenization changes the prompt slightly. generated_part = full_text.replace(prompt_part, "", 1).strip() # Clean up the generated text for display. cleaned_generated_part = re.sub(r'\n{2,}', '\n', generated_part).strip() escaped_generated = html.escape(cleaned_generated_part) escaped_prompt = html.escape(prompt_part) st.markdown(f"""
{tr('input_label')} {escaped_prompt}
{tr('generated_label')} {escaped_generated}
""", unsafe_allow_html=True) # Display the visualizations for each method. if hasattr(st.session_state, 'all_attributions'): st.header(tr('attribution_analysis_results_header')) # Create tabs for each analysis method. tab_titles = [ tr('saliency_tab'), tr('attr_tab'), tr('occlusion_tab') ] tabs = st.tabs(tab_titles) # Define the order of the methods in the tabs. methods = { "saliency": { "tab": tabs[0], "title": tr('saliency_title'), "description": tr('saliency_viz_desc') }, "integrated_gradients": { "tab": tabs[1], "title": tr('attr_title'), "description": tr('attr_viz_desc') }, "occlusion": { "tab": tabs[2], "title": tr('occlusion_title'), "description": tr('occlusion_viz_desc') } } # Generate and display the visualization for each method. for method_name, method_info in methods.items(): with method_info["tab"]: st.subheader(f"{method_info['title']} Analysis") # Generate the heatmap. with st.spinner(tr('creating_viz_spinner').format(method_title=method_info['title'])): heatmap_fig, html_content, heatmap_buffer, scores_df = create_heatmap_visualization( st.session_state.all_attributions[method_name], method_name=method_info['title'] ) if heatmap_fig: st.plotly_chart(heatmap_fig, use_container_width=True) # Add an explanation of how to read the heatmap. explanation_html = f"""

{tr('how_to_read_heatmap')}

""" st.markdown(explanation_html, unsafe_allow_html=True) # Generate an AI explanation for the heatmap. if (st.session_state.get('enable_explanations') and st.session_state.get('qwen_api_config') and heatmap_buffer is not None and scores_df is not None): explanation_cache_key = f"explanation_{method_name}_{st.session_state.generated_text}" # Get the explanation from the cache or generate it. if explanation_cache_key not in st.session_state: with st.spinner(tr('generating_ai_explanations_spinner').format(method_title=method_info['title'])): explanation = explain_heatmap_with_csv_data( st.session_state.qwen_api_config, heatmap_buffer, scores_df, st.session_state.prompt, st.session_state.generated_text, method_name ) st.session_state[explanation_cache_key] = explanation # Update cache file update_cache_with_explanation(st.session_state.prompt, method_name, explanation) explanation = st.session_state.get(explanation_cache_key) if explanation and not explanation.startswith("Error:"): simple_desc = tr(METHOD_DESC_KEYS.get(method_name, "unsupported_method_desc")) st.markdown(f"#### {tr('what_this_method_shows')}") st.markdown(f"""

{simple_desc}

""", unsafe_allow_html=True) html_explanation = markdown.markdown(explanation) st.markdown(f"#### {tr('ai_generated_analysis')}") st.markdown(f"""
{html_explanation}
""", unsafe_allow_html=True) # Faithfulness Check Expander with st.expander(tr('faithfulness_check_expander')): st.markdown(tr('faithfulness_check_explanation_html'), unsafe_allow_html=True) with st.spinner(tr('running_faithfulness_check_spinner')): try: # Use a cache key to avoid re-running the check unnecessarily. check_cache_key = f"faithfulness_check_{method_name}_{st.session_state.generated_text}" if check_cache_key not in st.session_state: claims = _cached_extract_claims_from_explanation( st.session_state.qwen_api_config, explanation, method_name ) if claims: analysis_data = { 'scores_df': scores_df, 'method': method_name, 'prompt': st.session_state.prompt, 'generated_text': st.session_state.generated_text } verification_results = verify_claims(claims, analysis_data) st.session_state[check_cache_key] = verification_results # Update cache file update_cache_with_faithfulness(st.session_state.prompt, method_name, verification_results) else: st.session_state[check_cache_key] = [] verification_results = st.session_state[check_cache_key] if verification_results: st.markdown(f"
{tr('faithfulness_check_results_header')}
", unsafe_allow_html=True) for result in verification_results: status_text = tr('verified_status') if result['verified'] else tr('contradicted_status') st.markdown(f"""

{tr('claim_label')}: "{result['claim_text']}"

{tr('status_label')}: {status_text}

{tr('evidence_label')}: {result['evidence']}

""", unsafe_allow_html=True) else: st.info(tr('no_verifiable_claims_info')) except Exception as e: st.error(tr('faithfulness_check_error').format(e=str(e))) # Add download buttons for the results. st.subheader(tr("download_results_subheader")) col1, col2 = st.columns(2) with col1: if html_content: st.download_button( label=tr("download_html_button").format(method_title=method_info['title']), data=html_content, file_name=f"{method_name}_analysis.html", mime="text/html", key=f"html_{method_name}" ) if scores_df is not None: st.download_button( label=tr("download_csv_button"), data=scores_df.to_csv().encode('utf-8'), file_name=f"{method_name}_scores.csv", mime="text/csv", key=f"csv_raw_{method_name}" ) with col2: if heatmap_fig: img_bytes = heatmap_fig.to_image(format="png", scale=2) st.download_button( label=tr("download_png_button").format(method_title=method_info['title']), data=img_bytes, file_name=f"{method_name}_heatmap.png", mime="image/png", key=f"png_{method_name}" ) # Display the influence tracer section. st.markdown("---") st.markdown(f'

{tr("influence_tracer_title")}

', unsafe_allow_html=True) st.markdown(f"
{tr('influence_tracer_desc')}
", unsafe_allow_html=True) # Add a visual explanation of cosine similarity. # Get translated text. sentence_a = tr('influence_example_sentence_a') sentence_b = tr('influence_example_sentence_b') # Create the SVG for the diagram. svg_code = f""" θ Vector A {sentence_a} Vector B {sentence_b} """ # Encode the SVG to base64. encoded_svg = base64.b64encode(svg_code.encode("utf-8")).decode("utf-8") image_uri = f"data:image/svg+xml;base64,{encoded_svg}" # Display the explanation and diagram. st.markdown(f"""

{tr('how_influence_is_found_header')}

{tr('how_influence_is_found_desc')}

{tr('influence_step_1_title')}: {tr('influence_step_1_desc')}

{tr('influence_step_2_title')}: {tr('influence_step_2_desc')}

{tr('influence_step_3_title')}: {tr('influence_step_3_desc')}

Cosine Similarity Diagram
""", unsafe_allow_html=True) st.write("") if hasattr(st.session_state, 'generated_text'): # First, check if influential docs are available in the cache from session_state if 'cached_influential_docs' in st.session_state: influential_docs = st.session_state.pop('cached_influential_docs') # Use and remove else: with st.spinner(tr('running_influence_trace_spinner')): lang = st.session_state.get('lang', 'en') influential_docs = get_influential_docs(st.session_state.prompt, lang) # Display the results. if influential_docs: st.markdown(f"#### {tr('top_influential_docs_header').format(num_docs=len(influential_docs))}") # A nice visualization for the influential documents. for i, doc in enumerate(influential_docs): colors = ["#A78BFA", "#7F9CF5", "#6EE7B7", "#FBBF24", "#F472B6"] card_color = colors[i % len(colors)] full_text = doc['text'] highlight_sentence = doc.get('highlight_sentence', '') highlighted_html = "" lang = st.session_state.get('lang', 'en') if highlight_sentence: # Normalize the sentence to be highlighted. normalized_highlight = re.sub(r'\s+', ' ', highlight_sentence).strip() # Use fuzzy matching to find the best match in the document. splitter = SentenceSplitter(language=lang) sentences_in_doc = splitter.split(text=full_text) if sentences_in_doc: best_match, score = process.extractOne(normalized_highlight, sentences_in_doc) start_index = full_text.find(best_match) if start_index != -1: end_index = start_index + len(best_match) # Create a context window around the matched sentence. context_window = 500 snippet_start = max(0, start_index - context_window) snippet_end = min(len(full_text), end_index + context_window) # Reconstruct the HTML with the highlighted sentence. before = html.escape(full_text[snippet_start:start_index]) highlight = html.escape(best_match) after = html.escape(full_text[end_index:snippet_end]) # Add ellipses if we're not showing the full text. start_ellipsis = "... " if snippet_start > 0 else "" end_ellipsis = " ..." if snippet_end < len(full_text) else "" highlighted_html = ( f"{start_ellipsis}{before}" f'{highlight}' f"{after}{end_ellipsis}" ) # If no highlight was applied, just show the full text. if not highlighted_html: highlighted_html = html.escape(full_text) st.markdown(f"""
{tr('source_label')}: {doc['source']} {tr('similarity_label')}: {doc['similarity']:.3f}
{highlighted_html.strip()}
""", unsafe_allow_html=True) else: # Give a helpful message if the index is missing. if not os.path.exists(INDEX_PATH) or not os.path.exists(MAPPING_PATH): st.warning(tr('influence_index_not_found_warning')) else: st.info(tr('no_influential_docs_found')) else: st.info(tr('run_analysis_for_influence_info')) # Show the feedback survey in the sidebar. #if 'all_attributions' in st.session_state: # display_attribution_feedback() if __name__ == "__main__": show_attribution_analysis()