# ELIA / web_app.py
# Latest commit 688668c (aaron0eidt): "Force dark mode theme for app"
import streamlit as st
from streamlit_option_menu import option_menu
import os
import sys
import base64
from pathlib import Path
# Import the page modules.
from attribution_analysis.attribution_analysis_page import show_attribution_analysis
from function_vectors.function_vectors_page import show_function_vectors_page
from circuit_analysis.circuit_trace_page import show_circuit_trace_page
from utilities.welcome_page import show_welcome_page
from utilities.utils import set_seed
from utilities.localization import initialize_localization, tr, language_selector
from utilities.feedback_survey import get_next_participant_id
# Import functions with persisted cache to clear them when needed.
from attribution_analysis.attribution_analysis_page import (
get_influential_docs,
_cached_explain_heatmap as attr_explain_heatmap,
generate_all_attribution_analyses
)
from circuit_analysis.circuit_trace_page import explain_circuit_visualization
from function_vectors.function_vectors_page import (
_perform_analysis as fv_perform_analysis,
_explain_with_llm as fv_explain_llm
)
# Page configuration for the Streamlit app. Kept ahead of the other st.* calls
# that follow in this file.
_PAGE_SETTINGS = {
    "page_title": "LLM Analysis Suite",
    "page_icon": "🧠",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_PAGE_SETTINGS)

# Process-wide environment tweaks:
#   TOKENIZERS_PARALLELISM=false               -> avoids tokenizer warnings.
#   OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES    -> suppresses a harmless error on macOS.
os.environ.update(
    {
        "TOKENIZERS_PARALLELISM": "false",
        "OBJC_DISABLE_INITIALIZE_FORK_SAFETY": "YES",
    }
)
# Custom stylesheet injected into the page. It hides the Streamlit toolbar
# (and with it the theme settings menu) and styles the main header, buttons,
# text areas, and the .attribution-info box.
_APP_STYLESHEET = """
<style>
/* Hide the theme settings menu to enforce dark mode */
[data-testid="stToolbar"] {
visibility: hidden;
}
.main-header {
font-size: 3rem;
color: #2f3f70;
text-align: center;
margin-bottom: 2rem;
}
.stButton > button {
background-color: #2f3f70;
color: #f5f7fb;
border-radius: 20px;
border: none;
padding: 0.5rem 2rem;
font-weight: bold;
box-shadow: 0 10px 20px rgba(47, 63, 112, 0.25);
}
.stButton > button:hover {
background-color: #3a4c86;
color: #ffffff;
}
.stTextArea > div > div > textarea {
border-radius: 10px;
}
.attribution-info {
background-color: rgba(47, 63, 112, 0.82);
color: #f5f7fb;
padding: 1rem;
border-radius: 10px;
margin: 1rem 0;
border-left: 4px solid #dcae36;
}
</style>
"""
# unsafe_allow_html is required so the <style> tag is emitted as raw HTML.
st.markdown(_APP_STYLESHEET, unsafe_allow_html=True)
def main():
    """Run the app.

    Calls ``set_seed()`` and ``initialize_localization()``, then renders one
    of two views:

    * the welcome page, while ``st.session_state.user_info`` is missing or
      its ``form_submitted`` entry is falsy;
    * otherwise the main application: a logo (or text title fallback), the
      sidebar menu, and the analysis page selected in that menu.
    """
    set_seed()
    initialize_localization()

    # The language selector lives on the welcome page, so none is added to
    # the sidebar of the main app.

    # Gate: until the welcome form has been submitted, show only that page.
    state = st.session_state
    welcome_done = "user_info" in state and state.user_info.get("form_submitted")
    if not welcome_done:
        show_welcome_page()
        return

    # Assign a participant ID once per session. The helper is only called
    # when no ID is stored yet.
    if 'participant_id' not in state:
        state.participant_id = get_next_participant_id()

    # Feedback-form flags start out as "not submitted".
    for feedback_flag in ('attr_feedback_submitted', 'fv_feedback_submitted'):
        if feedback_flag not in state:
            state[feedback_flag] = False

    # Header: embed the logo as a base64 data URI when the file exists,
    # otherwise fall back to a plain text title.
    logo_file_path = Path(__file__).parent.joinpath("LOGO", "Logo.png")
    if not logo_file_path.exists():
        st.markdown(f"<h1 class='main-header'>{tr('llm_analysis_suite')}</h1>", unsafe_allow_html=True)
    else:
        encoded_logo = base64.b64encode(logo_file_path.read_bytes()).decode("utf-8")
        alt_text = tr('llm_analysis_suite')
        logo_markup = (
            "\n"
            '<div style="text-align: center; margin-bottom: 2rem;">\n'
            f'<img src="data:image/png;base64,{encoded_logo}" alt="{alt_text}" '
            'style="max-width: 320px; width: 60%; min-width: 200px;" />\n'
            "</div>\n"
        )
        st.markdown(logo_markup, unsafe_allow_html=True)

    # Each route pairs a localization key with the function that renders the
    # page; the order here is the order shown in the sidebar menu.
    page_routes = (
        ('attribution_analysis', show_attribution_analysis),
        ('function_vectors', show_function_vectors_page),
        ('circuit_tracing', show_circuit_trace_page),
    )

    with st.sidebar:
        selected_page = option_menu(
            menu_title=tr('main_menu'),
            options=[tr(label_key) for label_key, _ in page_routes],
            icons=['search', 'cpu', 'diagram-3'],
            menu_icon='cast',
            default_index=0,
        )

    # Render the first route whose translated label matches the selection.
    for label_key, render_page in page_routes:
        if selected_page == tr(label_key):
            render_page()
            break
# Start the app only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()