updated to match previous app version
app_with_LLM.py (CHANGED, +541 -170)
@@ -1,7 +1,8 @@
 """
-File:
+File: app_with_LLM.py
 Description: Streamlit app for advanced topic modeling on Innerspeech dataset
-with BERTopic, UMAP, HDBSCAN.
+with BERTopic, UMAP, HDBSCAN.
+**PRO VERSION: LLM (LlamaCPP) Enabled**
 Last Modified: 08/12/2025
 """
 
@@ -18,31 +19,12 @@ import re
 import os
 import nltk
 import json
-import logging
 
-#
-
-
-
-#
-# set ENABLE_LLM to "True" in Hugging Face Space Settings to activate
-ENABLE_LLM = os.getenv("ENABLE_LLM", "False").lower() in ("true", "1", "yes")
-
-# Conditional Imports for Heavy LLM Libraries
-LLM_AVAILABLE = False
-if ENABLE_LLM:
-    try:
-        from llama_cpp import Llama
-        from bertopic.representation import LlamaCPP
-        from huggingface_hub import hf_hub_download
-        LLM_AVAILABLE = True
-        logger.info("LLM modules imported successfully.")
-    except ImportError as e:
-        logger.warning(f"ENABLE_LLM is True, but imports failed: {e}. Falling back to Lite mode.")
-        ENABLE_LLM = False
-
-# Standard Imports
-from mosaic.path_utils import CFG, raw_path, proc_path, eval_path, project_root  # type: ignore
+# --- LLM Specific Imports (Added for Pro Version) ---
+from llama_cpp import Llama
+from bertopic.representation import LlamaCPP
+from huggingface_hub import hf_hub_download
+# ----------------------------------------------------
 
 # BERTopic stack
 from bertopic import BERTopic
@@ -57,6 +39,7 @@ from hdbscan import HDBSCAN
 import datamapplot
 import matplotlib.pyplot as plt
 
+
 # =====================================================================
 # NLTK setup
 # =====================================================================
@@ -65,6 +48,7 @@ NLTK_DATA_DIR = "/usr/local/share/nltk_data"
 if NLTK_DATA_DIR not in nltk.data.path:
     nltk.data.path.append(NLTK_DATA_DIR)
 
+# Try to ensure both punkt_tab (new NLTK) and punkt (old NLTK) are available
 for resource in ("punkt_tab", "punkt"):
     try:
         nltk.data.find(f"tokenizers/{resource}")
@@ -81,10 +65,12 @@ for resource in ("punkt_tab", "punkt"):
 try:
     from mosaic.path_utils import CFG, raw_path, proc_path, eval_path, project_root  # type: ignore
 except Exception:
+    # Minimal stand-in so the app works anywhere (Streamlit Cloud, local without MOSAIC, etc.)
     def _env(key: str, default: str) -> Path:
         val = os.getenv(key, default)
         return Path(val).expanduser().resolve()
 
+    # Defaults: app-local data/ eval/ that are safe on Cloud
    _DATA_ROOT = _env("MOSAIC_DATA", str(Path(__file__).parent / "data"))
    _BOX_ROOT = _env("MOSAIC_BOX", str(Path(__file__).parent / "data" / "raw"))
    _EVAL_ROOT = _env("MOSAIC_EVAL", str(Path(__file__).parent / "eval"))
@@ -107,43 +93,63 @@ except Exception:
     def eval_path(*parts: str) -> Path:
         return _EVAL_ROOT.joinpath(*parts)
 
+
 # =====================================================================
 # 0. Constants & Helper Functions
 # =====================================================================
 
+
 def _slugify(s: str) -> str:
     s = s.strip()
     s = re.sub(r"[^A-Za-z0-9._-]+", "_", s)
     return s or "DATASET"
 
 def _cleanup_old_cache(current_slug: str):
+    """Deletes precomputed .npy files that do not match the current dataset slug."""
     if not CACHE_DIR.exists():
         return
+
     removed_count = 0
+    # Iterate over all precomputed files
     for p in CACHE_DIR.glob("precomputed_*.npy"):
+        # If the file belongs to a different dataset (doesn't contain the new slug)
         if current_slug not in p.name:
             try:
-                p.unlink()
+                p.unlink()  # Delete file
                 removed_count += 1
             except Exception as e:
                 print(f"Error deleting {p.name}: {e}")
+
     if removed_count > 0:
         print(f"Auto-cleanup: Removed {removed_count} old cache files.")
 
 ACCEPTABLE_TEXT_COLUMNS = [
-    "reflection_answer_english",
+    "reflection_answer_english",
+    "reflection_answer",
+    "text",
+    "report",
 ]
 
+
 def _pick_text_column(df: pd.DataFrame) -> str | None:
+    """Return the first matching *preferred* text column name if present."""
     for col in ACCEPTABLE_TEXT_COLUMNS:
         if col in df.columns:
             return col
     return None
 
+
 def _list_text_columns(df: pd.DataFrame) -> list[str]:
+    """
+    Return all columns; we’ll cast the chosen one to string later.
+    This makes the selector work with any column name / dtype.
+    """
     return list(df.columns)
 
+
+
 def _set_from_env_or_secrets(key: str):
+    """Allow hosting: value can come from environment or from Streamlit secrets."""
     if os.getenv(key):
         return
     try:
@@ -153,44 +159,58 @@ def _set_from_env_or_secrets(key: str):
     if val:
         os.environ[key] = str(val)
 
+
+# Enable both MOSAIC_DATA and MOSAIC_BOX automatically
 for _k in ("MOSAIC_DATA", "MOSAIC_BOX"):
     _set_from_env_or_secrets(_k)
 
+
 @st.cache_data
 def count_clean_reports(csv_path: str, text_col: str | None = None) -> int:
+    """Count non-empty reports in the chosen text column."""
     df = pd.read_csv(csv_path)
-
+
+    if text_col is not None and text_col in df.columns:
+        col = text_col
+    else:
+        col = _pick_text_column(df)
+
     if col is None:
         return 0
+
     if col != "reflection_answer_english":
         df = df.rename(columns={col: "reflection_answer_english"})
+
     df.dropna(subset=["reflection_answer_english"], inplace=True)
     df["reflection_answer_english"] = df["reflection_answer_english"].astype(str)
     df = df[df["reflection_answer_english"].str.strip() != ""]
     return len(df)
 
+
 # =====================================================================
 # 1. Streamlit app setup
 # =====================================================================
 
-st.set_page_config(page_title="MOSAIC Dashboard", layout="wide")
-st.title(
-
-
-
+st.set_page_config(page_title="MOSAIC Dashboard (Pro)", layout="wide")
+st.title(
+    "Mapping of Subjective Accounts into Interpreted Clusters (MOSAIC): "
+    "Topic Modelling Dashboard (Pro Version)"
+)
 
 st.markdown(
     """
 _If you use this tool in your research, please cite the following paper:_\n
-**Beauté, R., et al. (2025).** **Mapping of Subjective Accounts into Interpreted Clusters (MOSAIC)** https://arxiv.org/abs/2502.18318
+**Beauté, R., et al. (2025).** **Mapping of Subjective Accounts into Interpreted Clusters (MOSAIC): Topic Modelling and LLM applied to Stroboscopic Phenomenology** https://arxiv.org/abs/2502.18318
     """
 )
 
 # =====================================================================
-# 2. Dataset paths
+# 2. Dataset paths (using MOSAIC structure)
 # =====================================================================
 
-ds_input = st.sidebar.text_input(
+ds_input = st.sidebar.text_input(
+    "Project/Dataset name", value="MOSAIC", key="dataset_name_input"
+)
 DATASET_DIR = _slugify(ds_input).upper()
 
 RAW_DIR = raw_path(DATASET_DIR)
@@ -202,36 +222,50 @@ PROC_DIR.mkdir(parents=True, exist_ok=True)
 CACHE_DIR.mkdir(parents=True, exist_ok=True)
 EVAL_DIR.mkdir(parents=True, exist_ok=True)
 
+with st.sidebar.expander("About the dataset name", expanded=False):
+    st.markdown(
+        f"""
+- The name above is converted to **UPPER CASE** and used as a folder name.
+- If the folder doesn’t exist, it will be **created**:
+  - Preprocessed CSVs: `{PROC_DIR}`
+  - Exports (results): `{EVAL_DIR}`
+- If you choose **Use preprocessed CSV on server**, I’ll list CSVs in `{PROC_DIR}`.
+- If you **upload** a CSV, it will be saved to `{PROC_DIR}/uploaded.csv`.
+        """.strip()
+    )
+
+
 def _list_server_csvs(proc_dir: Path) -> list[str]:
     return [str(p) for p in sorted(proc_dir.glob("*.csv"))]
 
+
+DATASETS = None  # keep name for clarity; we’ll fill it when rendering the sidebar
 HISTORY_FILE = str(PROC_DIR / "run_history.json")
 
 # =====================================================================
 # 3. Embedding & LLM loaders
 # =====================================================================
 
+
 @st.cache_resource
 def load_embedding_model(model_name):
     st.info(f"Loading embedding model '{model_name}'...")
     return SentenceTransformer(model_name)
 
+# --- Added for Pro Version ---
 @st.cache_resource
 def load_llm_model():
-    """Loads LlamaCPP model
-
-
-
-    st.info("Loading Llama-3-8B-Instruct (Quantized)...")
+    """Loads LlamaCPP quantised model for topic labeling."""
+    st.info("Loading Llama-3-8B-Instruct (Quantized)... This may take a moment.")
+    model_repo = "NousResearch/Meta-Llama-3-8B-Instruct-GGUF"
+    model_file = "Meta-Llama-3-8B-Instruct-Q4_K_M.gguf"
     try:
-        model_repo = "NousResearch/Meta-Llama-3-8B-Instruct-GGUF"
-        model_file = "Meta-Llama-3-8B-Instruct-Q4_K_M.gguf"
         model_path = hf_hub_download(repo_id=model_repo, filename=model_file)
-        # n_gpu_layers=-1 attempts to offload all to GPU
         return Llama(model_path=model_path, n_gpu_layers=-1, n_ctx=8192, stop=["Q:", "\n"], verbose=False)
     except Exception as e:
         st.error(f"Failed to load LLM: {e}")
         return None
+# -----------------------------
 
 @st.cache_data
 def load_precomputed_data(docs_file, embeddings_file):
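For context on the loader above: `hf_hub_download` fetches the quantized GGUF weights into the local Hugging Face cache and returns the file path, and `n_gpu_layers=-1` asks llama-cpp-python to offload every layer to the GPU when one is available. A minimal standalone sketch of the same pattern (not part of the diff; the toy prompt is a placeholder):

```python
from huggingface_hub import hf_hub_download
from llama_cpp import Llama

# Download (or reuse from cache) the quantized GGUF checkpoint.
path = hf_hub_download(
    repo_id="NousResearch/Meta-Llama-3-8B-Instruct-GGUF",
    filename="Meta-Llama-3-8B-Instruct-Q4_K_M.gguf",
)

# n_gpu_layers=-1 offloads all layers to GPU if available; n_ctx is the context window.
llm = Llama(model_path=path, n_gpu_layers=-1, n_ctx=8192, verbose=False)

# The Q:/A: prompt style pairs with stop=["Q:", "\n"]: generation halts after a
# single-line answer instead of continuing with the next "Q:".
out = llm("Q: Name one primary colour.\nA:", max_tokens=16, stop=["Q:", "\n"])
print(out["choices"][0]["text"].strip())
```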
@@ -239,63 +273,86 @@ def load_precomputed_data(docs_file, embeddings_file):
     emb = np.load(embeddings_file, allow_pickle=True)
     return docs, emb
 
+
 # =====================================================================
 # 4. Topic modeling function
 # =====================================================================
 
+
 def get_config_hash(cfg):
     return json.dumps(cfg, sort_keys=True)
 
+
 @st.cache_data
-def perform_topic_modeling(_docs, _embeddings, config_hash
+def perform_topic_modeling(_docs, _embeddings, config_hash):
     """Fit BERTopic using cached result."""
+
     _docs = list(_docs)
     _embeddings = np.asarray(_embeddings)
     if _embeddings.dtype == object or _embeddings.ndim != 2:
         try:
             _embeddings = np.vstack(_embeddings)
         except Exception:
-            st.error(
+            st.error(
+                f"Embeddings are invalid (dtype={_embeddings.dtype}, ndim={_embeddings.ndim}). "
+                "Please click **Prepare Data** to regenerate."
+            )
             st.stop()
     _embeddings = np.ascontiguousarray(_embeddings, dtype=np.float32)
 
     if _embeddings.shape[0] != len(_docs):
-        st.error(
+        st.error(
+            f"Mismatch between docs and embeddings: len(docs)={len(_docs)} vs "
+            f"embeddings.shape[0]={_embeddings.shape[0]}. "
+            "Delete the cached files for this configuration and regenerate."
+        )
         st.stop()
 
     config = json.loads(config_hash)
 
     if "ngram_range" in config["vectorizer_params"]:
-        config["vectorizer_params"]["ngram_range"] = tuple(
+        config["vectorizer_params"]["ngram_range"] = tuple(
+            config["vectorizer_params"]["ngram_range"]
+        )
 
-    # --- Representation
+    # --- LLM Representation Setup (Added for Pro Version) ---
+    llm = load_llm_model()
     rep_model = None
-
-
-
-        prompt = """Q:
+
+    if llm:
+        prompt = """Q:
 You are an expert in micro-phenomenology. The following documents are reflections from participants about their experience.
 I have a topic that contains the following documents:
 [DOCUMENTS]
 The topic is described by the following keywords: '[KEYWORDS]'.
 Based on the above information, give a short, informative label (5–10 words).
 A:"""
-
-
-
-
-        print("LLM requested but load failed; falling back to default representation.")
+        rep_model = {
+            "LLM": LlamaCPP(llm, prompt=prompt, nr_docs=25, doc_length=300, tokenizer="whitespace")
+        }
+    # -----------------------------------------------------
 
     umap_model = UMAP(random_state=42, metric="cosine", **config["umap_params"])
-    hdbscan_model = HDBSCAN(
-
-
+    hdbscan_model = HDBSCAN(
+        metric="euclidean", prediction_data=True, **config["hdbscan_params"]
+    )
+    vectorizer_model = (
+        CountVectorizer(**config["vectorizer_params"])
+        if config["use_vectorizer"]
+        else None
+    )
+
+    nr_topics_val = (
+        None
+        if config["bt_params"]["nr_topics"] == "auto"
+        else int(config["bt_params"]["nr_topics"])
+    )
 
     topic_model = BERTopic(
         umap_model=umap_model,
         hdbscan_model=hdbscan_model,
         vectorizer_model=vectorizer_model,
-        representation_model=rep_model,
+        representation_model=rep_model,  # <-- Pass LLM representation here
         top_n_words=config["bt_params"]["top_n_words"],
         nr_topics=nr_topics_val,
         verbose=False,
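Why `config_hash` is a JSON string: `st.cache_data` keys its cache on the function arguments, and a nested dict containing tuples is awkward to hash stably, so the app serializes the config with `json.dumps(..., sort_keys=True)` and re-parses it inside the cached function. JSON turns tuples into lists, which is exactly why `ngram_range` is converted back with `tuple(...)` above. A small self-contained illustration:

```python
import json

cfg = {"vectorizer_params": {"ngram_range": (1, 2), "min_df": 1}}

key = json.dumps(cfg, sort_keys=True)   # stable string -> stable cache key
restored = json.loads(key)              # note: the tuple came back as a list

restored["vectorizer_params"]["ngram_range"] = tuple(
    restored["vectorizer_params"]["ngram_range"]
)
assert restored["vectorizer_params"]["ngram_range"] == (1, 2)
```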
@@ -306,34 +363,60 @@ A:"""
 
     outlier_pct = 0
     if -1 in info.Topic.values:
-        outlier_pct = (
+        outlier_pct = (
+            info.Count[info.Topic == -1].iloc[0] / info.Count.sum()
+        ) * 100
 
-    #
-    if
-        # Extract LLM labels if available
+    # --- Extract Labels (Prefer LLM if available) ---
+    if rep_model and "LLM" in topic_model.get_topics(full=True):
         raw_labels = [label[0][0] for label in topic_model.get_topics(full=True)["LLM"].values()]
         cleaned_labels = [lbl.split(":")[-1].strip().strip('"').strip(".") for lbl in raw_labels]
         final_labels = [lbl if lbl else "Unlabelled" for lbl in cleaned_labels]
-        # Map back to docs
         all_labels = [final_labels[topic + topic_model._outliers] if topic != -1 else "Unlabelled" for topic in topics]
     else:
-        #
-
+        # Fallback for when LLM fails or is not present
+        topic_info = topic_model.get_topic_info()
+        name_map = topic_info.set_index("Topic")["Name"].to_dict()
         all_labels = [name_map[topic] for topic in topics]
+    # -----------------------------------------------
 
-    reduced = UMAP(
+    reduced = UMAP(
+        n_neighbors=15,
+        n_components=2,
+        min_dist=0.0,
+        metric="cosine",
+        random_state=42,
+    ).fit_transform(_embeddings)
 
     return topic_model, reduced, all_labels, len(info) - 1, outlier_pct
 
+
 # =====================================================================
 # 5. CSV → documents → embeddings pipeline
 # =====================================================================
 
-
+
+def generate_and_save_embeddings(
+    csv_path,
+    docs_file,
+    emb_file,
+    selected_embedding_model,
+    split_sentences,
+    device,
+    text_col=None,
+):
+
+    # ---------------------
+    # Load & clean CSV
+    # ---------------------
     st.info(f"Reading and preparing CSV: {csv_path}")
     df = pd.read_csv(csv_path)
 
-
+    if text_col is not None and text_col in df.columns:
+        col = text_col
+    else:
+        col = _pick_text_column(df)
+
     if col is None:
         st.error("CSV must contain at least one text column.")
         return
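The `final_labels[topic + topic_model._outliers]` indexing above leans on a BERTopic internal: when an outlier topic (-1) exists, `_outliers` is 1 and the label sequence returned by `get_topics(full=True)` starts with the outlier's entry, so topic `t` sits at list position `t + 1`. A toy illustration with hypothetical labels (the `_outliers` value here is assumed, mirroring BERTopic's behavior):

```python
# Hypothetical label list as returned (in order) for topics -1, 0, 1:
final_labels = ["outlier label", "label for topic 0", "label for topic 1"]
_outliers = 1  # BERTopic sets this to 1 when topic -1 is present

for topic in (0, 1):
    assert final_labels[topic + _outliers] == f"label for topic {topic}"
```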
@@ -346,6 +429,9 @@ def generate_and_save_embeddings(csv_path, docs_file, emb_file, selected_embeddi
     df = df[df["reflection_answer_english"].str.strip() != ""]
     reports = df["reflection_answer_english"].tolist()
 
+    # ---------------------
+    # Sentence / report granularity
+    # ---------------------
     if split_sentences:
         try:
             sentences = [s for r in reports for s in nltk.sent_tokenize(r)]
@@ -359,13 +445,28 @@
     np.save(docs_file, np.array(docs, dtype=object))
     st.success(f"Prepared {len(docs)} documents")
 
-
-
+    # ---------------------
+    # Embeddings
+    # ---------------------
+    st.info(
+        f"Encoding {len(docs)} documents with {selected_embedding_model} on {device}"
+    )
 
-
-    batch_size = 64 if device == "CPU" else 32
+    model = load_embedding_model(selected_embedding_model)
 
-
+    encode_device = None
+    batch_size = 32
+    if device == "CPU":
+        encode_device = "cpu"
+        batch_size = 64
+
+    embeddings = model.encode(
+        docs,
+        show_progress_bar=True,
+        batch_size=batch_size,
+        device=encode_device,
+        convert_to_numpy=True,
+    )
     embeddings = np.asarray(embeddings, dtype=np.float32)
     np.save(emb_file, embeddings)
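On the encode call above (sentence-transformers API): `device=None` lets the library auto-select CUDA, MPS, or CPU, while the app forces `"cpu"` with a larger batch when the user picks CPU. A minimal sketch, assuming the `BAAI/bge-small-en-v1.5` model from the sidebar list:

```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("BAAI/bge-small-en-v1.5")
emb = model.encode(
    ["a short report", "another report"],
    batch_size=64,
    device="cpu",            # None would auto-select CUDA/MPS/CPU
    convert_to_numpy=True,
)
print(emb.shape)  # (2, 384) for this model
```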
@@ -373,43 +474,84 @@
     st.balloons()
     st.rerun()
 
+
 # =====================================================================
 # 6. Sidebar — dataset, upload, parameters
 # =====================================================================
 
 st.sidebar.header("Data Input Method")
 
-source = st.sidebar.radio(
+source = st.sidebar.radio(
+    "Choose data source",
+    ("Use preprocessed CSV on server", "Upload my own CSV"),
+    index=0,
+    key="data_source",
+)
 
 uploaded_csv_path = None
-CSV_PATH = None
+CSV_PATH = None  # will be set in the chosen branch
 
 if source == "Use preprocessed CSV on server":
     available = _list_server_csvs(PROC_DIR)
     if not available:
-        st.info(
+        st.info(
+            f"No CSVs found in {PROC_DIR}. Switch to 'Upload my own CSV' or change the dataset name."
+        )
         st.stop()
-    selected_csv = st.sidebar.selectbox(
+    selected_csv = st.sidebar.selectbox(
+        "Choose a preprocessed CSV", available, key="server_csv_select"
+    )
     CSV_PATH = selected_csv
 else:
-    up = st.sidebar.file_uploader(
+    up = st.sidebar.file_uploader(
+        "Upload a CSV", type=["csv"], key="upload_csv"
+    )
+
+    st.sidebar.caption(
+        "Your CSV should have **one row per report** and at least one text column "
+        "(for example `reflection_answer_english`, `reflection_answer`, `text`, `report`, "
+        "or any other column containing free text). "
+        "Other columns (ID, condition, etc.) are allowed. "
+        "After upload, you’ll be able to choose which text column to analyse."
+    )
+
+
     if up is not None:
+        # List of encodings to try:
+        # 1. utf-8 (Standard)
+        # 2. mac_roman (Fixes the Õ and É issues from Mac Excel)
+        # 3. cp1252 (Standard Windows Excel)
         encodings_to_try = ['utf-8', 'mac_roman', 'cp1252', 'ISO-8859-1']
+
         tmp_df = None
+        success_encoding = None
+
         for encoding in encodings_to_try:
             try:
-                up.seek(0)
+                up.seek(0)  # Always reset to start of file before trying
                 tmp_df = pd.read_csv(up, encoding=encoding)
-
+                success_encoding = encoding
+                break  # If we get here, it worked, so stop the loop
             except UnicodeDecodeError:
-                continue
-
-
+                continue  # If it fails, try the next one
+
+        if tmp_df is None:
+            st.error("Could not decode file. Please save your CSV as 'CSV UTF-8' in Excel.")
             st.stop()
-
+
+        if tmp_df.empty:
+            st.error("Uploaded CSV is empty.")
+            st.stop()
+
+        # Optional: Print which encoding worked to the logs (for your info)
+        print(f"Successfully loaded CSV using {success_encoding} encoding.")
+
+        # FIX: Use the original filename to avoid cache collisions
+        # We sanitize the name to be safe for file systems
         safe_filename = _slugify(os.path.splitext(up.name)[0])
         _cleanup_old_cache(safe_filename)
         uploaded_csv_path = str((PROC_DIR / f"{safe_filename}.csv").resolve())
+
         tmp_df.to_csv(uploaded_csv_path, index=False)
         st.success(f"Uploaded CSV saved to {uploaded_csv_path}")
         CSV_PATH = uploaded_csv_path
@@ -420,115 +562,219 @@ else:
 if CSV_PATH is None:
     st.stop()
 
+# ---------------------------------------------------------------------
 # Text column selection
+# ---------------------------------------------------------------------
+
+
 @st.cache_data
 def get_text_columns(csv_path: str) -> list[str]:
     df_sample = pd.read_csv(csv_path, nrows=2000)
     return _list_text_columns(df_sample)
 
 text_columns = get_text_columns(CSV_PATH)
+
 if not text_columns:
-    st.error(
+    st.error(
+        "No columns found in this CSV. At least one column is required."
+    )
     st.stop()
 
-
+
+text_columns = get_text_columns(CSV_PATH)
+if not text_columns:
+    st.error(
+        "No text-like columns found in this CSV. At least one column must contain text."
+    )
+    st.stop()
+
+# Try to pick a nice default (one of the MOSAIC-ish names) if present
 try:
     df_sample = pd.read_csv(CSV_PATH, nrows=2000)
     preferred = _pick_text_column(df_sample)
-except Exception:
+except Exception:
+    preferred = None
 
-
-
+if preferred in text_columns:
+    default_idx = text_columns.index(preferred)
+else:
+    default_idx = 0
 
+selected_text_column = st.sidebar.selectbox(
+    "Text column to analyse",
+    text_columns,
+    index=default_idx,
+    key="text_column_select",
+)
+
-# Data Granularity
+# ---------------------------------------------------------------------
+# Data granularity & subsampling
+# ---------------------------------------------------------------------
 st.sidebar.subheader("Data Granularity & Subsampling")
-
+
+selected_granularity = st.sidebar.checkbox(
+    "Split reports into sentences", value=True
+)
 granularity_label = "sentences" if selected_granularity else "reports"
+
 subsample_perc = st.sidebar.slider("Data sampling (%)", 10, 100, 100, 5)
 
 st.sidebar.markdown("---")
 
-#
+# ---------------------------------------------------------------------
+# Embedding model & device
+# ---------------------------------------------------------------------
+
 st.sidebar.header("Model Selection")
-
-
-    "
-
-
-
-
+
+selected_embedding_model = st.sidebar.selectbox(
+    "Choose an embedding model",
+    (
+        "BAAI/bge-small-en-v1.5",
+        "intfloat/multilingual-e5-large-instruct",
+        "Qwen/Qwen3-Embedding-0.6B",
+        "sentence-transformers/all-mpnet-base-v2",
+    ),
+)
+
+selected_device = st.sidebar.radio(
+    "Processing device",
+    ["GPU (MPS)", "CPU"],
+    index=0,
+)
 
 # =====================================================================
-# 7. Precompute filenames
+# 7. Precompute filenames and pipeline triggers
 # =====================================================================
 
+
 def get_precomputed_filenames(csv_path, model_name, split_sentences, text_col):
     base = os.path.splitext(os.path.basename(csv_path))[0]
     safe_model = re.sub(r"[^a-zA-Z0-9_-]", "_", model_name)
     suf = "sentences" if split_sentences else "reports"
-
+
+    col_suffix = ""
+    if text_col:
+        safe_col = re.sub(r"[^a-zA-Z0-9_-]", "_", text_col)
+        col_suffix = f"_{safe_col}"
+
     return (
         str(CACHE_DIR / f"precomputed_{base}{col_suffix}_{suf}_docs.npy"),
-        str(
+        str(
+            CACHE_DIR
+            / f"precomputed_{base}_{safe_model}{col_suffix}_{suf}_embeddings.npy"
+        ),
     )
 
-DOCS_FILE, EMBEDDINGS_FILE = get_precomputed_filenames(CSV_PATH, selected_embedding_model, selected_granularity, selected_text_column)
+
+DOCS_FILE, EMBEDDINGS_FILE = get_precomputed_filenames(
+    CSV_PATH, selected_embedding_model, selected_granularity, selected_text_column
+)
 
+# --- Cache management ---
 st.sidebar.markdown("### Cache")
-if st.sidebar.button(
-    for
-
-
-
-
-
-
+if st.sidebar.button(
+    "Clear cached files for this configuration", use_container_width=True
+):
+    try:
+        for p in (DOCS_FILE, EMBEDDINGS_FILE):
+            if os.path.exists(p):
+                os.remove(p)
+        try:
+            load_precomputed_data.clear()
+        except Exception:
+            pass
+        try:
+            perform_topic_modeling.clear()
+        except Exception:
+            pass
+
+        st.success(
+            "Deleted cached docs/embeddings and cleared caches. Click **Prepare Data** again."
+        )
+        st.rerun()
+    except Exception as e:
+        st.error(f"Failed to delete cache files: {e}")
 
 st.sidebar.markdown("---")
 
 # =====================================================================
-# 8. Run Analysis
+# 8. Prepare Data OR Run Analysis
 # =====================================================================
 
 if not os.path.exists(EMBEDDINGS_FILE):
-    st.warning(
+    st.warning(
+        f"No precomputed embeddings found for this configuration "
+        f"({granularity_label} / {selected_embedding_model} / column '{selected_text_column}')."
+    )
+
     if st.button("Prepare Data for This Configuration"):
-        generate_and_save_embeddings(
+        generate_and_save_embeddings(
+            CSV_PATH,
+            DOCS_FILE,
+            EMBEDDINGS_FILE,
+            selected_embedding_model,
+            selected_granularity,
+            selected_device,
+            text_col=selected_text_column,
+        )
+
 else:
+    # Load cached data
     docs, embeddings = load_precomputed_data(DOCS_FILE, EMBEDDINGS_FILE)
+
     embeddings = np.asarray(embeddings)
     if embeddings.dtype == object or embeddings.ndim != 2:
         try:
             embeddings = np.vstack(embeddings).astype(np.float32)
         except Exception:
-            st.error(
+            st.error(
+                "Cached embeddings are invalid. Please regenerate them for this configuration."
+            )
             st.stop()
 
     if subsample_perc < 100:
         n = int(len(docs) * (subsample_perc / 100))
         idx = np.random.choice(len(docs), size=n, replace=False)
         docs = [docs[i] for i in idx]
-        embeddings = embeddings[idx, :]
-        st.warning(
+        embeddings = np.asarray(embeddings)[idx, :]
+        st.warning(
+            f"Running analysis on {subsample_perc}% subsample ({len(docs)} documents)"
+        )
 
+    # Dataset summary
     st.subheader("Dataset summary")
     n_reports = count_clean_reports(CSV_PATH, selected_text_column)
-
+    unit = "sentences" if selected_granularity else "reports"
+    n_units = len(docs)
 
-
+    c1, c2 = st.columns(2)
+    c1.metric("Reports in CSV (cleaned)", n_reports)
+    c2.metric(f"Units analysed ({unit})", n_units)
+
+    # --- Parameter controls ---
     st.sidebar.header("Model Parameters")
+
     use_vectorizer = st.sidebar.checkbox("Use CountVectorizer", value=True)
+
     with st.sidebar.expander("Vectorizer"):
-        ng_min
+        ng_min = st.slider("Min N-gram", 1, 5, 1)
+        ng_max = st.slider("Max N-gram", 1, 5, 2)
         min_df = st.slider("Min Doc Freq", 1, 50, 1)
-        stopwords = st.select_slider(
+        stopwords = st.select_slider(
+            "Stopwords", options=[None, "english"], value=None
+        )
+
     with st.sidebar.expander("UMAP"):
         um_n = st.slider("n_neighbors", 2, 50, 15)
         um_c = st.slider("n_components", 2, 20, 5)
         um_d = st.slider("min_dist", 0.0, 1.0, 0.0)
+
     with st.sidebar.expander("HDBSCAN"):
         hs = st.slider("min_cluster_size", 5, 100, 10)
         hm = st.slider("min_samples", 2, 100, 5)
+
     with st.sidebar.expander("BERTopic"):
         nr_topics = st.text_input("nr_topics", value="auto")
         top_n_words = st.slider("top_n_words", 5, 25, 10)
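For reference, the cache-file naming scheme defined in `get_precomputed_filenames` produces names like the following (hypothetical inputs, runnable as plain Python):

```python
import os
import re

base = os.path.splitext(os.path.basename("MYDATA.csv"))[0]             # "MYDATA"
safe_model = re.sub(r"[^a-zA-Z0-9_-]", "_", "BAAI/bge-small-en-v1.5")  # "BAAI_bge-small-en-v1_5"
col_suffix = "_reflection_answer_english"
suf = "sentences"

print(f"precomputed_{base}{col_suffix}_{suf}_docs.npy")
print(f"precomputed_{base}_{safe_model}{col_suffix}_{suf}_embeddings.npy")
```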
@@ -538,12 +784,25 @@ else:
         "granularity": granularity_label,
         "subsample_percent": subsample_perc,
         "use_vectorizer": use_vectorizer,
-        "vectorizer_params": {
-
-
-
+        "vectorizer_params": {
+            "ngram_range": (ng_min, ng_max),
+            "min_df": min_df,
+            "stop_words": stopwords,
+        },
+        "umap_params": {
+            "n_neighbors": um_n,
+            "n_components": um_c,
+            "min_dist": um_d,
+        },
+        "hdbscan_params": {
+            "min_cluster_size": hs,
+            "min_samples": hm,
+        },
+        "bt_params": {
+            "nr_topics": nr_topics,
+            "top_n_words": top_n_words,
+        },
         "text_column": selected_text_column,
-        "llm_enabled": ENABLE_LLM
     }
 
     run_button = st.sidebar.button("Run Analysis", type="primary")
@@ -555,13 +814,16 @@ else:
 
     def load_history():
         path = HISTORY_FILE
-        if not os.path.exists(path):
+        if not os.path.exists(path):
+            return []
         try:
             data = json.load(open(path))
-
-
-
-
+        except Exception:
+            return []
+        for e in data:
+            if "outlier_pct" not in e and "outlier_perc" in e:
+                e["outlier_pct"] = e.pop("outlier_perc")
+        return data
 
     def save_history(h):
         json.dump(h, open(HISTORY_FILE, "w"), indent=2)
@@ -570,10 +832,29 @@ else:
     st.session_state.history = load_history()
 
     if run_button:
+        if not isinstance(embeddings, np.ndarray):
+            embeddings = np.asarray(embeddings)
+
+        if embeddings.dtype == object or embeddings.ndim != 2:
+            try:
+                embeddings = np.vstack(embeddings).astype(np.float32)
+            except Exception:
+                st.error(
+                    "Cached embeddings are invalid (object/ragged). Click **Prepare Data** to regenerate."
+                )
+                st.stop()
+
+        if embeddings.shape[0] != len(docs):
+            st.error(
+                f"len(docs)={len(docs)} but embeddings.shape[0]={embeddings.shape[0]}.\n"
+                "Likely stale cache (e.g., switched sentences↔reports or model). "
+                "Use the **Clear cache** button below and regenerate."
+            )
+            st.stop()
+
         with st.spinner("Performing topic modeling..."):
-            # Pass the global ENABLE_LLM flag to the cached function
             model, reduced, labels, n_topics, outlier_pct = perform_topic_modeling(
-                docs, embeddings, get_config_hash(current_config)
+                docs, embeddings, get_config_hash(current_config)
             )
             st.session_state.latest_results = (model, reduced, labels)
 
@@ -582,56 +863,143 @@ else:
             "config": current_config,
             "num_topics": n_topics,
             "outlier_pct": f"{outlier_pct:.2f}%",
-            "llm_labels": [
+            "llm_labels": [
+                name
+                for name in model.get_topic_info().Name.values
+                if ("Unlabelled" not in name and "outlier" not in name)
+            ],
         }
         st.session_state.history.insert(0, entry)
         save_history(st.session_state.history)
         st.rerun()
 
+    # --- MAIN TAB ---
     with main_tab:
         if "latest_results" in st.session_state:
             tm, reduced, labs = st.session_state.latest_results
+
             st.subheader("Experiential Topics Visualisation")
             fig, _ = datamapplot.create_plot(reduced, labs)
             st.pyplot(fig)
+
             st.subheader("Topic Info")
             st.dataframe(tm.get_topic_info())
-
-
-
+
+            st.subheader("Export results (one row per topic)")
+
             full_reps = tm.get_topics(full=True)
             llm_reps = full_reps.get("LLM", {})
-
-            # Determine how to name topics based on what mode we ran
+
             llm_names = {}
-
-
-
-
-
-
-
-
+            for tid, vals in llm_reps.items():
+                try:
+                    llm_names[tid] = (
+                        (vals[0][0] or "").strip().strip('"').strip(".")
+                    )
+                except Exception:
+                    llm_names[tid] = "Unlabelled"
+
+            if not llm_names:
+                st.caption("Note: Using default keyword-based topic names.")
+                llm_names = (
+                    tm.get_topic_info().set_index("Topic")["Name"].to_dict()
+                )
 
             doc_info = tm.get_document_info(docs)[["Document", "Topic"]]
-
-
-
-
-
-
+
+            include_outliers = st.checkbox(
+                "Include outlier topic (-1)", value=False
+            )
+            if not include_outliers:
+                doc_info = doc_info[doc_info["Topic"] != -1]
+
+            grouped = (
+                doc_info.groupby("Topic")["Document"]
+                .apply(list)
+                .reset_index(name="texts")
+            )
+            grouped["topic_name"] = grouped["Topic"].map(llm_names).fillna(
+                "Unlabelled"
+            )
+
+            export_topics = (
+                grouped.rename(columns={"Topic": "topic_id"})[
+                    ["topic_id", "topic_name", "texts"]
+                ]
+                .sort_values("topic_id")
+                .reset_index(drop=True)
+            )
+
+            SEP = "\n"
+
+            export_csv = export_topics.copy()
+            export_csv["texts"] = export_csv["texts"].apply(
+                lambda lst: SEP.join(map(str, lst))
+            )
+
             base = os.path.splitext(os.path.basename(CSV_PATH))[0]
             gran = "sentences" if selected_granularity else "reports"
-
-
-
-
-
-
-
-
-
+            csv_name = f"topics_{base}_{gran}.csv"
+            jsonl_name = f"topics_{base}_{gran}.jsonl"
+            csv_path = (EVAL_DIR / csv_name).resolve()
+            jsonl_path = (EVAL_DIR / jsonl_name).resolve()
+
+            cL, cC, cR = st.columns(3)
+
+            with cL:
+                if st.button("Save CSV to eval/", use_container_width=True):
+                    try:
+                        export_csv.to_csv(csv_path, index=False)
+                        st.success(f"Saved CSV → {csv_path}")
+                    except Exception as e:
+                        st.error(f"Failed to save CSV: {e}")
+
+            with cC:
+                if st.button("Save JSONL to eval/", use_container_width=True):
+                    try:
+                        with open(jsonl_path, "w", encoding="utf-8") as f:
+                            for _, row in export_topics.iterrows():
+                                rec = {
+                                    "topic_id": int(row["topic_id"]),
+                                    "topic_name": row["topic_name"],
+                                    "texts": list(map(str, row["texts"])),
+                                }
+                                f.write(
+                                    json.dumps(rec, ensure_ascii=False) + "\n"
+                                )
+                        st.success(f"Saved JSONL → {jsonl_path}")
+                    except Exception as e:
+                        st.error(f"Failed to save JSONL: {e}")
+
+            with cR:
+
+                # Create a Long Format DataFrame (One row per sentence)
+                # This ensures NO text is hidden due to Excel cell limits
+                long_format_df = doc_info.copy()
+                long_format_df["Topic Name"] = long_format_df["Topic"].map(llm_names).fillna("Unlabelled")
+
+                # Reorder columns for clarity
+                long_format_df = long_format_df[["Topic", "Topic Name", "Document"]]
+
+                # Define filename
+                long_csv_name = f"all_sentences_{base}_{gran}.csv"
+
+                st.download_button(
+                    "Download All Sentences (Long Format)",
+                    data=long_format_df.to_csv(index=False).encode("utf-8-sig"),
+                    file_name=long_csv_name,
+                    mime="text/csv",
+                    use_container_width=True,
+                    help="Download a CSV with one row per sentence. Best for checking exactly which sentences belong to which topic."
+                )
+
+            # st.caption("Preview (one row per topic)")
+            st.dataframe(export_csv)
+
+        else:
+            st.info("Click 'Run Analysis' to begin.")
 
+    # --- HISTORY TAB ---
     with history_tab:
         st.subheader("Run History")
         if not st.session_state.history:
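The JSONL export above writes one topic per line with `topic_id`, `topic_name`, and the full `texts` list, so it round-trips cleanly. A short sketch of reading it back (hypothetical filename):

```python
import json

with open("topics_MYDATA_sentences.jsonl", encoding="utf-8") as f:
    topics = [json.loads(line) for line in f]

for t in topics:
    print(t["topic_id"], t["topic_name"], len(t["texts"]))
```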
@@ -640,7 +1008,10 @@ else:
         for i, entry in enumerate(st.session_state.history):
             with st.expander(f"Run {i+1} — {entry['timestamp']}"):
                 st.write(f"**Topics:** {entry['num_topics']}")
-                st.write(
+                st.write(
+                    f"**Outliers:** {entry.get('outlier_pct', entry.get('outlier_perc', 'N/A'))}"
+                )
                 st.write("**Topic Labels:**")
-                st.write(entry
-                st.
+                st.write(entry["llm_labels"])
+                with st.expander("Show full configuration"):
+                    st.json(entry["config"])