romybeaute committed
Commit 6b6764a · verified · 1 Parent(s): 45756cc

updated to match previous app version

Files changed (1): app_with_LLM.py +541 -170
app_with_LLM.py CHANGED
@@ -1,7 +1,8 @@
  """
- File: app.py
  Description: Streamlit app for advanced topic modeling on Innerspeech dataset
- with BERTopic, UMAP, HDBSCAN. Supports conditional LLM (LlamaCPP) execution.
  Last Modified: 08/12/2025
  """
@@ -18,31 +19,12 @@ import re
  import os
  import nltk
  import json
- import logging

- # Set up logging
- logging.basicConfig(level=logging.INFO)
- logger = logging.getLogger(__name__)
-
- # Check for LLM enablement via Environment Variable
- # set ENABLE_LLM to "True" in Hugging Face Space Settings to activate
- ENABLE_LLM = os.getenv("ENABLE_LLM", "False").lower() in ("true", "1", "yes")
-
- # Conditional Imports for Heavy LLM Libraries
- LLM_AVAILABLE = False
- if ENABLE_LLM:
-     try:
-         from llama_cpp import Llama
-         from bertopic.representation import LlamaCPP
-         from huggingface_hub import hf_hub_download
-         LLM_AVAILABLE = True
-         logger.info("LLM modules imported successfully.")
-     except ImportError as e:
-         logger.warning(f"ENABLE_LLM is True, but imports failed: {e}. Falling back to Lite mode.")
-         ENABLE_LLM = False
-
- # Standard Imports
- from mosaic.path_utils import CFG, raw_path, proc_path, eval_path, project_root  # type: ignore

  # BERTopic stack
  from bertopic import BERTopic
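For reference, the Lite/Pro runtime split being removed here hinged on a one-line environment toggle. A minimal sketch of that pattern, using the names from the old file:

    import os

    # Any of "true"/"1"/"yes" (case-insensitive) enables the LLM path;
    # unset or anything else falls back to Lite mode.
    ENABLE_LLM = os.getenv("ENABLE_LLM", "False").lower() in ("true", "1", "yes")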
@@ -57,6 +39,7 @@ from hdbscan import HDBSCAN
  import datamapplot
  import matplotlib.pyplot as plt

  # =====================================================================
  # NLTK setup
  # =====================================================================
@@ -65,6 +48,7 @@ NLTK_DATA_DIR = "/usr/local/share/nltk_data"
  if NLTK_DATA_DIR not in nltk.data.path:
      nltk.data.path.append(NLTK_DATA_DIR)

  for resource in ("punkt_tab", "punkt"):
      try:
          nltk.data.find(f"tokenizers/{resource}")
@@ -81,10 +65,12 @@ for resource in ("punkt_tab", "punkt"):
  try:
      from mosaic.path_utils import CFG, raw_path, proc_path, eval_path, project_root  # type: ignore
  except Exception:

      def _env(key: str, default: str) -> Path:
          val = os.getenv(key, default)
          return Path(val).expanduser().resolve()

      _DATA_ROOT = _env("MOSAIC_DATA", str(Path(__file__).parent / "data"))
      _BOX_ROOT = _env("MOSAIC_BOX", str(Path(__file__).parent / "data" / "raw"))
      _EVAL_ROOT = _env("MOSAIC_EVAL", str(Path(__file__).parent / "eval"))
@@ -107,43 +93,63 @@ except Exception:
      def eval_path(*parts: str) -> Path:
          return _EVAL_ROOT.joinpath(*parts)

  # =====================================================================
  # 0. Constants & Helper Functions
  # =====================================================================

  def _slugify(s: str) -> str:
      s = s.strip()
      s = re.sub(r"[^A-Za-z0-9._-]+", "_", s)
      return s or "DATASET"

  def _cleanup_old_cache(current_slug: str):
      if not CACHE_DIR.exists():
          return
      removed_count = 0
      for p in CACHE_DIR.glob("precomputed_*.npy"):
          if current_slug not in p.name:
              try:
-                 p.unlink()
                  removed_count += 1
              except Exception as e:
                  print(f"Error deleting {p.name}: {e}")
      if removed_count > 0:
          print(f"Auto-cleanup: Removed {removed_count} old cache files.")

  ACCEPTABLE_TEXT_COLUMNS = [
-     "reflection_answer_english", "reflection_answer", "text", "report",
  ]

  def _pick_text_column(df: pd.DataFrame) -> str | None:
      for col in ACCEPTABLE_TEXT_COLUMNS:
          if col in df.columns:
              return col
      return None

  def _list_text_columns(df: pd.DataFrame) -> list[str]:
      return list(df.columns)

  def _set_from_env_or_secrets(key: str):
      if os.getenv(key):
          return
      try:
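A quick check of what `_slugify` produces, on hypothetical inputs:

    import re

    def _slugify(s: str) -> str:
        s = s.strip()
        s = re.sub(r"[^A-Za-z0-9._-]+", "_", s)
        return s or "DATASET"

    assert _slugify("my study (2025)!") == "my_study_2025_"
    assert _slugify("   ") == "DATASET"   # empty input falls back to the default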
@@ -153,44 +159,58 @@ def _set_from_env_or_secrets(key: str):
          if val:
              os.environ[key] = str(val)

  for _k in ("MOSAIC_DATA", "MOSAIC_BOX"):
      _set_from_env_or_secrets(_k)

  @st.cache_data
  def count_clean_reports(csv_path: str, text_col: str | None = None) -> int:
      df = pd.read_csv(csv_path)
-     col = text_col if (text_col and text_col in df.columns) else _pick_text_column(df)
      if col is None:
          return 0
      if col != "reflection_answer_english":
          df = df.rename(columns={col: "reflection_answer_english"})
      df.dropna(subset=["reflection_answer_english"], inplace=True)
      df["reflection_answer_english"] = df["reflection_answer_english"].astype(str)
      df = df[df["reflection_answer_english"].str.strip() != ""]
      return len(df)

  # =====================================================================
  # 1. Streamlit app setup
  # =====================================================================

- st.set_page_config(page_title="MOSAIC Dashboard", layout="wide")
- st.title("Mapping of Subjective Accounts into Interpreted Clusters (MOSAIC)")
-
- mode_status = "🟢 Pro Mode (LLM Enabled)" if ENABLE_LLM else "🟡 Lite Mode (LLM Disabled)"
- st.caption(f"Current Runtime Status: **{mode_status}**")

  st.markdown(
      """
  _If you use this tool in your research, please cite the following paper:_\n
- **Beauté, R., et al. (2025).** **Mapping of Subjective Accounts into Interpreted Clusters (MOSAIC)** https://arxiv.org/abs/2502.18318
      """
  )

  # =====================================================================
- # 2. Dataset paths
  # =====================================================================

- ds_input = st.sidebar.text_input("Project/Dataset name", value="MOSAIC", key="dataset_name_input")
  DATASET_DIR = _slugify(ds_input).upper()

  RAW_DIR = raw_path(DATASET_DIR)
@@ -202,36 +222,50 @@ PROC_DIR.mkdir(parents=True, exist_ok=True)
  CACHE_DIR.mkdir(parents=True, exist_ok=True)
  EVAL_DIR.mkdir(parents=True, exist_ok=True)

  def _list_server_csvs(proc_dir: Path) -> list[str]:
      return [str(p) for p in sorted(proc_dir.glob("*.csv"))]

  HISTORY_FILE = str(PROC_DIR / "run_history.json")

  # =====================================================================
  # 3. Embedding & LLM loaders
  # =====================================================================

  @st.cache_resource
  def load_embedding_model(model_name):
      st.info(f"Loading embedding model '{model_name}'...")
      return SentenceTransformer(model_name)

  @st.cache_resource
  def load_llm_model():
-     """Loads LlamaCPP model only if ENABLE_LLM is True."""
-     if not ENABLE_LLM or not LLM_AVAILABLE:
-         return None
-
-     st.info("Loading Llama-3-8B-Instruct (Quantized)...")
      try:
-         model_repo = "NousResearch/Meta-Llama-3-8B-Instruct-GGUF"
-         model_file = "Meta-Llama-3-8B-Instruct-Q4_K_M.gguf"
          model_path = hf_hub_download(repo_id=model_repo, filename=model_file)
-         # n_gpu_layers=-1 attempts to offload all to GPU
          return Llama(model_path=model_path, n_gpu_layers=-1, n_ctx=8192, stop=["Q:", "\n"], verbose=False)
      except Exception as e:
          st.error(f"Failed to load LLM: {e}")
          return None

  @st.cache_data
  def load_precomputed_data(docs_file, embeddings_file):
@@ -239,63 +273,86 @@ def load_precomputed_data(docs_file, embeddings_file):
      emb = np.load(embeddings_file, allow_pickle=True)
      return docs, emb

  # =====================================================================
  # 4. Topic modeling function
  # =====================================================================

  def get_config_hash(cfg):
      return json.dumps(cfg, sort_keys=True)
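Why the code re-tuples `ngram_range` after `json.loads`: JSON has no tuple type, so the sorted-key dump used as a cache key turns tuples into lists on the round trip. A small demonstration:

    import json

    cfg = {"vectorizer_params": {"ngram_range": (1, 2)}, "use_vectorizer": True}
    key = json.dumps(cfg, sort_keys=True)      # deterministic cache key
    back = json.loads(key)
    assert back["vectorizer_params"]["ngram_range"] == [1, 2]   # list, not tuple
    back["vectorizer_params"]["ngram_range"] = tuple(back["vectorizer_params"]["ngram_range"])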
 
 
  @st.cache_data
- def perform_topic_modeling(_docs, _embeddings, config_hash, use_llm_flag=False):
      """Fit BERTopic using cached result."""
      _docs = list(_docs)
      _embeddings = np.asarray(_embeddings)
      if _embeddings.dtype == object or _embeddings.ndim != 2:
          try:
              _embeddings = np.vstack(_embeddings)
          except Exception:
-             st.error("Embeddings are invalid. Regenerate data.")
              st.stop()
      _embeddings = np.ascontiguousarray(_embeddings, dtype=np.float32)

      if _embeddings.shape[0] != len(_docs):
-         st.error("Mismatch between docs and embeddings.")
          st.stop()

      config = json.loads(config_hash)

      if "ngram_range" in config["vectorizer_params"]:
-         config["vectorizer_params"]["ngram_range"] = tuple(config["vectorizer_params"]["ngram_range"])

-     # --- Representation Model Logic ---
      rep_model = None
-     if use_llm_flag and LLM_AVAILABLE:
-         llm = load_llm_model()
-         if llm:
-             prompt = """Q:
  You are an expert in micro-phenomenology. The following documents are reflections from participants about their experience.
  I have a topic that contains the following documents:
  [DOCUMENTS]
  The topic is described by the following keywords: '[KEYWORDS]'.
  Based on the above information, give a short, informative label (5–10 words).
  A:"""
-             rep_model = {
-                 "LLM": LlamaCPP(llm, prompt=prompt, nr_docs=25, doc_length=300, tokenizer="whitespace")
-             }
-         else:
-             print("LLM requested but load failed; falling back to default representation.")

      umap_model = UMAP(random_state=42, metric="cosine", **config["umap_params"])
-     hdbscan_model = HDBSCAN(metric="euclidean", prediction_data=True, **config["hdbscan_params"])
-     vectorizer_model = CountVectorizer(**config["vectorizer_params"]) if config["use_vectorizer"] else None
-     nr_topics_val = None if config["bt_params"]["nr_topics"] == "auto" else int(config["bt_params"]["nr_topics"])

      topic_model = BERTopic(
          umap_model=umap_model,
          hdbscan_model=hdbscan_model,
          vectorizer_model=vectorizer_model,
-         representation_model=rep_model,
          top_n_words=config["bt_params"]["top_n_words"],
          nr_topics=nr_topics_val,
          verbose=False,
@@ -306,34 +363,60 @@ A:"""
      outlier_pct = 0
      if -1 in info.Topic.values:
-         outlier_pct = (info.Count[info.Topic == -1].iloc[0] / info.Count.sum()) * 100

-     # Label extraction
-     if use_llm_flag and rep_model and "LLM" in topic_model.get_topics(full=True):
-         # Extract LLM labels if available
          raw_labels = [label[0][0] for label in topic_model.get_topics(full=True)["LLM"].values()]
          cleaned_labels = [lbl.split(":")[-1].strip().strip('"').strip(".") for lbl in raw_labels]
          final_labels = [lbl if lbl else "Unlabelled" for lbl in cleaned_labels]
-         # Map back to docs
          all_labels = [final_labels[topic + topic_model._outliers] if topic != -1 else "Unlabelled" for topic in topics]
      else:
-         # Default labels
-         name_map = info.set_index("Topic")["Name"].to_dict()
          all_labels = [name_map[topic] for topic in topics]

-     reduced = UMAP(n_neighbors=15, n_components=2, min_dist=0.0, metric="cosine", random_state=42).fit_transform(_embeddings)

      return topic_model, reduced, all_labels, len(info) - 1, outlier_pct

  # =====================================================================
  # 5. CSV → documents → embeddings pipeline
  # =====================================================================

- def generate_and_save_embeddings(csv_path, docs_file, emb_file, selected_embedding_model, split_sentences, device, text_col=None):
      st.info(f"Reading and preparing CSV: {csv_path}")
      df = pd.read_csv(csv_path)

-     col = text_col if (text_col and text_col in df.columns) else _pick_text_column(df)
      if col is None:
          st.error("CSV must contain at least one text column.")
          return
@@ -346,6 +429,9 @@ def generate_and_save_embeddings(csv_path, docs_file, emb_file, selected_embeddi
      df = df[df["reflection_answer_english"].str.strip() != ""]
      reports = df["reflection_answer_english"].tolist()

      if split_sentences:
          try:
              sentences = [s for r in reports for s in nltk.sent_tokenize(r)]
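The sentence-granularity path relies on the NLTK Punkt tokenizer set up earlier; a minimal, self-contained example of the flattening comprehension:

    import nltk

    nltk.download("punkt")                      # punkt_tab on newer NLTK releases
    reports = ["I felt calm. Then images appeared.", "One sentence only."]
    sentences = [s for r in reports for s in nltk.sent_tokenize(r)]
    # ['I felt calm.', 'Then images appeared.', 'One sentence only.']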
@@ -359,13 +445,28 @@ def generate_and_save_embeddings(csv_path, docs_file, emb_file, selected_embeddi
      np.save(docs_file, np.array(docs, dtype=object))
      st.success(f"Prepared {len(docs)} documents")

-     st.info(f"Encoding {len(docs)} documents with {selected_embedding_model} on {device}")
-     model = load_embedding_model(selected_embedding_model)

-     encode_device = "cpu" if device == "CPU" else None
-     batch_size = 64 if device == "CPU" else 32

-     embeddings = model.encode(docs, show_progress_bar=True, batch_size=batch_size, device=encode_device, convert_to_numpy=True)
      embeddings = np.asarray(embeddings, dtype=np.float32)
      np.save(emb_file, embeddings)
@@ -373,43 +474,84 @@ def generate_and_save_embeddings(csv_path, docs_file, emb_file, selected_embeddi
      st.balloons()
      st.rerun()

  # =====================================================================
  # 6. Sidebar — dataset, upload, parameters
  # =====================================================================

  st.sidebar.header("Data Input Method")

- source = st.sidebar.radio("Choose data source", ("Use preprocessed CSV on server", "Upload my own CSV"), index=0, key="data_source")

  uploaded_csv_path = None
- CSV_PATH = None

  if source == "Use preprocessed CSV on server":
      available = _list_server_csvs(PROC_DIR)
      if not available:
-         st.info(f"No CSVs found in {PROC_DIR}. Upload a CSV.")
          st.stop()
-     selected_csv = st.sidebar.selectbox("Choose a preprocessed CSV", available, key="server_csv_select")
      CSV_PATH = selected_csv
  else:
-     up = st.sidebar.file_uploader("Upload a CSV", type=["csv"], key="upload_csv")
      if up is not None:
          encodings_to_try = ['utf-8', 'mac_roman', 'cp1252', 'ISO-8859-1']
          tmp_df = None
          for encoding in encodings_to_try:
              try:
-                 up.seek(0)
                  tmp_df = pd.read_csv(up, encoding=encoding)
-                 break
              except UnicodeDecodeError:
-                 continue
-         if tmp_df is None or tmp_df.empty:
-             st.error("Could not decode file.")
              st.stop()
-
          safe_filename = _slugify(os.path.splitext(up.name)[0])
          _cleanup_old_cache(safe_filename)
          uploaded_csv_path = str((PROC_DIR / f"{safe_filename}.csv").resolve())
          tmp_df.to_csv(uploaded_csv_path, index=False)
          st.success(f"Uploaded CSV saved to {uploaded_csv_path}")
          CSV_PATH = uploaded_csv_path
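The upload branch above tries several encodings in order; the same idea as a standalone helper (function name hypothetical):

    import io
    import pandas as pd

    def read_csv_any_encoding(buf, encodings=("utf-8", "mac_roman", "cp1252", "ISO-8859-1")):
        """Rewind the buffer before each attempt; return (df, encoding) or (None, None)."""
        for enc in encodings:
            try:
                buf.seek(0)
                return pd.read_csv(buf, encoding=enc), enc
            except UnicodeDecodeError:
                continue
        return None, None

    df, used = read_csv_any_encoding(io.BytesIO("col\nvalué\n".encode("cp1252")))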
@@ -420,115 +562,219 @@ else:
  if CSV_PATH is None:
      st.stop()

  # Text column selection
  @st.cache_data
  def get_text_columns(csv_path: str) -> list[str]:
      df_sample = pd.read_csv(csv_path, nrows=2000)
      return _list_text_columns(df_sample)

  text_columns = get_text_columns(CSV_PATH)
  if not text_columns:
-     st.error("No columns found in this CSV.")
      st.stop()

- preferred = None
  try:
      df_sample = pd.read_csv(CSV_PATH, nrows=2000)
      preferred = _pick_text_column(df_sample)
- except Exception: pass

- default_idx = text_columns.index(preferred) if preferred in text_columns else 0
- selected_text_column = st.sidebar.selectbox("Text column to analyse", text_columns, index=default_idx, key="text_column_select")

- # Data Granularity
  st.sidebar.subheader("Data Granularity & Subsampling")
- selected_granularity = st.sidebar.checkbox("Split reports into sentences", value=True)
  granularity_label = "sentences" if selected_granularity else "reports"
  subsample_perc = st.sidebar.slider("Data sampling (%)", 10, 100, 100, 5)

  st.sidebar.markdown("---")

- # Model Selection
  st.sidebar.header("Model Selection")
- selected_embedding_model = st.sidebar.selectbox("Choose an embedding model", (
-     "BAAI/bge-small-en-v1.5",
-     "intfloat/multilingual-e5-large-instruct",
-     "Qwen/Qwen3-Embedding-0.6B",
-     "sentence-transformers/all-mpnet-base-v2",
- ))
- selected_device = st.sidebar.radio("Processing device", ["GPU (MPS)", "CPU"], index=0)

  # =====================================================================
- # 7. Precompute filenames
  # =====================================================================

  def get_precomputed_filenames(csv_path, model_name, split_sentences, text_col):
      base = os.path.splitext(os.path.basename(csv_path))[0]
      safe_model = re.sub(r"[^a-zA-Z0-9_-]", "_", model_name)
      suf = "sentences" if split_sentences else "reports"
-     col_suffix = f"_{re.sub(r'[^a-zA-Z0-9_-]', '_', text_col)}" if text_col else ""
      return (
          str(CACHE_DIR / f"precomputed_{base}{col_suffix}_{suf}_docs.npy"),
-         str(CACHE_DIR / f"precomputed_{base}_{safe_model}{col_suffix}_{suf}_embeddings.npy"),
      )

- DOCS_FILE, EMBEDDINGS_FILE = get_precomputed_filenames(CSV_PATH, selected_embedding_model, selected_granularity, selected_text_column)

  st.sidebar.markdown("### Cache")
- if st.sidebar.button("Clear cached files for this configuration", use_container_width=True):
-     for p in (DOCS_FILE, EMBEDDINGS_FILE):
-         if os.path.exists(p):
-             os.remove(p)
-     load_precomputed_data.clear()
-     perform_topic_modeling.clear()
-     st.success("Cache cleared.")
-     st.rerun()

  st.sidebar.markdown("---")

  # =====================================================================
- # 8. Run Analysis
  # =====================================================================

  if not os.path.exists(EMBEDDINGS_FILE):
-     st.warning(f"No precomputed embeddings found for this configuration.")
      if st.button("Prepare Data for This Configuration"):
-         generate_and_save_embeddings(CSV_PATH, DOCS_FILE, EMBEDDINGS_FILE, selected_embedding_model, selected_granularity, selected_device, text_col=selected_text_column)
  else:
      docs, embeddings = load_precomputed_data(DOCS_FILE, EMBEDDINGS_FILE)
      embeddings = np.asarray(embeddings)
      if embeddings.dtype == object or embeddings.ndim != 2:
          try:
              embeddings = np.vstack(embeddings).astype(np.float32)
          except Exception:
-             st.error("Cached embeddings are invalid. Regenerate.")
              st.stop()

      if subsample_perc < 100:
          n = int(len(docs) * (subsample_perc / 100))
          idx = np.random.choice(len(docs), size=n, replace=False)
          docs = [docs[i] for i in idx]
-         embeddings = embeddings[idx, :]
-         st.warning(f"Running analysis on {subsample_perc}% subsample ({len(docs)} documents)")

      st.subheader("Dataset summary")
      n_reports = count_clean_reports(CSV_PATH, selected_text_column)
-     st.metric("Units analysed", len(docs))

-     # Parameters
      st.sidebar.header("Model Parameters")
      use_vectorizer = st.sidebar.checkbox("Use CountVectorizer", value=True)
      with st.sidebar.expander("Vectorizer"):
-         ng_min, ng_max = st.slider("N-gram Range", 1, 5, (1, 2))
          min_df = st.slider("Min Doc Freq", 1, 50, 1)
-         stopwords = st.select_slider("Stopwords", options=[None, "english"], value=None)
      with st.sidebar.expander("UMAP"):
          um_n = st.slider("n_neighbors", 2, 50, 15)
          um_c = st.slider("n_components", 2, 20, 5)
          um_d = st.slider("min_dist", 0.0, 1.0, 0.0)
      with st.sidebar.expander("HDBSCAN"):
          hs = st.slider("min_cluster_size", 5, 100, 10)
          hm = st.slider("min_samples", 2, 100, 5)
      with st.sidebar.expander("BERTopic"):
          nr_topics = st.text_input("nr_topics", value="auto")
          top_n_words = st.slider("top_n_words", 5, 25, 10)
@@ -538,12 +784,25 @@ else:
          "granularity": granularity_label,
          "subsample_percent": subsample_perc,
          "use_vectorizer": use_vectorizer,
-         "vectorizer_params": {"ngram_range": (ng_min, ng_max), "min_df": min_df, "stop_words": stopwords},
-         "umap_params": {"n_neighbors": um_n, "n_components": um_c, "min_dist": um_d},
-         "hdbscan_params": {"min_cluster_size": hs, "min_samples": hm},
-         "bt_params": {"nr_topics": nr_topics, "top_n_words": top_n_words},
          "text_column": selected_text_column,
-         "llm_enabled": ENABLE_LLM
      }

      run_button = st.sidebar.button("Run Analysis", type="primary")
@@ -555,13 +814,16 @@ else:
      def load_history():
          path = HISTORY_FILE
-         if not os.path.exists(path): return []
          try:
              data = json.load(open(path))
-             for e in data:
-                 if "outlier_pct" not in e and "outlier_perc" in e: e["outlier_pct"] = e.pop("outlier_perc")
-             return data
-         except Exception: return []

      def save_history(h):
          json.dump(h, open(HISTORY_FILE, "w"), indent=2)
@@ -570,10 +832,29 @@ else:
      st.session_state.history = load_history()

      if run_button:
          with st.spinner("Performing topic modeling..."):
-             # Pass the global ENABLE_LLM flag to the cached function
              model, reduced, labels, n_topics, outlier_pct = perform_topic_modeling(
-                 docs, embeddings, get_config_hash(current_config), use_llm_flag=ENABLE_LLM
              )
          st.session_state.latest_results = (model, reduced, labels)
@@ -582,56 +863,143 @@ else:
              "config": current_config,
              "num_topics": n_topics,
              "outlier_pct": f"{outlier_pct:.2f}%",
-             "llm_labels": [name for name in model.get_topic_info().Name.values if ("Unlabelled" not in name and "outlier" not in name)],
          }
          st.session_state.history.insert(0, entry)
          save_history(st.session_state.history)
          st.rerun()

      with main_tab:
          if "latest_results" in st.session_state:
              tm, reduced, labs = st.session_state.latest_results
              st.subheader("Experiential Topics Visualisation")
              fig, _ = datamapplot.create_plot(reduced, labs)
              st.pyplot(fig)
              st.subheader("Topic Info")
              st.dataframe(tm.get_topic_info())
-
-             # Export Logic
-             st.subheader("Export results")
              full_reps = tm.get_topics(full=True)
              llm_reps = full_reps.get("LLM", {})
-
-             # Determine how to name topics based on what mode we ran
              llm_names = {}
-             if ENABLE_LLM and llm_reps:
-                 for tid, vals in llm_reps.items():
-                     try:
-                         llm_names[tid] = (vals[0][0] or "").strip().strip('"').strip(".")
-                     except Exception:
-                         llm_names[tid] = "Unlabelled"
-             else:
-                 llm_names = tm.get_topic_info().set_index("Topic")["Name"].to_dict()

              doc_info = tm.get_document_info(docs)[["Document", "Topic"]]
-
-             # --- Long Format Export (One row per sentence) ---
-             long_format_df = doc_info.copy()
-             long_format_df["Topic Name"] = long_format_df["Topic"].map(llm_names).fillna("Unlabelled")
-             long_format_df = long_format_df[["Topic", "Topic Name", "Document"]]
-
              base = os.path.splitext(os.path.basename(CSV_PATH))[0]
              gran = "sentences" if selected_granularity else "reports"
-             long_csv_name = f"all_sentences_{base}_{gran}.csv"
-
-             st.download_button(
-                 "Download All Sentences (Long Format)",
-                 data=long_format_df.to_csv(index=False).encode("utf-8-sig"),
-                 file_name=long_csv_name,
-                 mime="text/csv",
-                 use_container_width=True
-             )

      with history_tab:
          st.subheader("Run History")
          if not st.session_state.history:
@@ -640,7 +1008,10 @@ else:
          for i, entry in enumerate(st.session_state.history):
              with st.expander(f"Run {i+1} — {entry['timestamp']}"):
                  st.write(f"**Topics:** {entry['num_topics']}")
-                 st.write(f"**Outliers:** {entry.get('outlier_pct', 'N/A')}")
                  st.write("**Topic Labels:**")
-                 st.write(entry.get("llm_labels", []))
-                 st.json(entry["config"])
  """
+ File: app_with_LLM.py
  Description: Streamlit app for advanced topic modeling on Innerspeech dataset
+ with BERTopic, UMAP, HDBSCAN.
+ **PRO VERSION: LLM (LlamaCPP) Enabled**
  Last Modified: 08/12/2025
  """

  import os
  import nltk
  import json

+ # --- LLM Specific Imports (Added for Pro Version) ---
+ from llama_cpp import Llama
+ from bertopic.representation import LlamaCPP
+ from huggingface_hub import hf_hub_download
+ # ----------------------------------------------------

  # BERTopic stack
  from bertopic import BERTopic

  import datamapplot
  import matplotlib.pyplot as plt
+
  # =====================================================================
  # NLTK setup
  # =====================================================================

  if NLTK_DATA_DIR not in nltk.data.path:
      nltk.data.path.append(NLTK_DATA_DIR)

+ # Try to ensure both punkt_tab (new NLTK) and punkt (old NLTK) are available
  for resource in ("punkt_tab", "punkt"):
      try:
          nltk.data.find(f"tokenizers/{resource}")

  try:
      from mosaic.path_utils import CFG, raw_path, proc_path, eval_path, project_root  # type: ignore
  except Exception:
+     # Minimal stand-in so the app works anywhere (Streamlit Cloud, local without MOSAIC, etc.)
      def _env(key: str, default: str) -> Path:
          val = os.getenv(key, default)
          return Path(val).expanduser().resolve()

+     # Defaults: app-local data/ eval/ that are safe on Cloud
      _DATA_ROOT = _env("MOSAIC_DATA", str(Path(__file__).parent / "data"))
      _BOX_ROOT = _env("MOSAIC_BOX", str(Path(__file__).parent / "data" / "raw"))
      _EVAL_ROOT = _env("MOSAIC_EVAL", str(Path(__file__).parent / "eval"))

      def eval_path(*parts: str) -> Path:
          return _EVAL_ROOT.joinpath(*parts)

+
  # =====================================================================
  # 0. Constants & Helper Functions
  # =====================================================================

+
  def _slugify(s: str) -> str:
      s = s.strip()
      s = re.sub(r"[^A-Za-z0-9._-]+", "_", s)
      return s or "DATASET"

  def _cleanup_old_cache(current_slug: str):
+     """Deletes precomputed .npy files that do not match the current dataset slug."""
      if not CACHE_DIR.exists():
          return
+
      removed_count = 0
+     # Iterate over all precomputed files
      for p in CACHE_DIR.glob("precomputed_*.npy"):
+         # If the file belongs to a different dataset (doesn't contain the new slug)
          if current_slug not in p.name:
              try:
+                 p.unlink()  # Delete file
                  removed_count += 1
              except Exception as e:
                  print(f"Error deleting {p.name}: {e}")
+
      if removed_count > 0:
          print(f"Auto-cleanup: Removed {removed_count} old cache files.")

  ACCEPTABLE_TEXT_COLUMNS = [
+     "reflection_answer_english",
+     "reflection_answer",
+     "text",
+     "report",
  ]

+
  def _pick_text_column(df: pd.DataFrame) -> str | None:
+     """Return the first matching *preferred* text column name if present."""
      for col in ACCEPTABLE_TEXT_COLUMNS:
          if col in df.columns:
              return col
      return None

+
  def _list_text_columns(df: pd.DataFrame) -> list[str]:
+     """
+     Return all columns; we'll cast the chosen one to string later.
+     This makes the selector work with any column name / dtype.
+     """
      return list(df.columns)

+
  def _set_from_env_or_secrets(key: str):
+     """Allow hosting: value can come from environment or from Streamlit secrets."""
      if os.getenv(key):
          return
      try:

      if val:
          os.environ[key] = str(val)

+
+ # Enable both MOSAIC_DATA and MOSAIC_BOX automatically
  for _k in ("MOSAIC_DATA", "MOSAIC_BOX"):
      _set_from_env_or_secrets(_k)
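The diff elides the unchanged middle of `_set_from_env_or_secrets`. A sketch of the full pattern, under the assumption that the missing line reads the key from `st.secrets` (which behaves like a mapping):

    import os
    import streamlit as st

    def _set_from_env_or_secrets(key: str):
        if os.getenv(key):                      # environment wins
            return
        try:
            val = st.secrets.get(key)           # assumed: secrets fallback
        except Exception:
            val = None
        if val:
            os.environ[key] = str(val)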
 
+
  @st.cache_data
  def count_clean_reports(csv_path: str, text_col: str | None = None) -> int:
+     """Count non-empty reports in the chosen text column."""
      df = pd.read_csv(csv_path)
+
+     if text_col is not None and text_col in df.columns:
+         col = text_col
+     else:
+         col = _pick_text_column(df)
+
      if col is None:
          return 0
+
      if col != "reflection_answer_english":
          df = df.rename(columns={col: "reflection_answer_english"})
+
      df.dropna(subset=["reflection_answer_english"], inplace=True)
      df["reflection_answer_english"] = df["reflection_answer_english"].astype(str)
      df = df[df["reflection_answer_english"].str.strip() != ""]
      return len(df)
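What the cleaning steps in `count_clean_reports` do to a toy frame (hypothetical data):

    import pandas as pd

    df = pd.DataFrame({"text": ["A real report.", "", None, "   "]})
    df = df.rename(columns={"text": "reflection_answer_english"})
    df.dropna(subset=["reflection_answer_english"], inplace=True)
    df["reflection_answer_english"] = df["reflection_answer_english"].astype(str)
    df = df[df["reflection_answer_english"].str.strip() != ""]
    assert len(df) == 1                          # only the non-empty report survives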
 
+
  # =====================================================================
  # 1. Streamlit app setup
  # =====================================================================

+ st.set_page_config(page_title="MOSAIC Dashboard (Pro)", layout="wide")
+ st.title(
+     "Mapping of Subjective Accounts into Interpreted Clusters (MOSAIC): "
+     "Topic Modelling Dashboard (Pro Version)"
+ )

  st.markdown(
      """
  _If you use this tool in your research, please cite the following paper:_\n
+ **Beauté, R., et al. (2025).** **Mapping of Subjective Accounts into Interpreted Clusters (MOSAIC): Topic Modelling and LLM applied to Stroboscopic Phenomenology** https://arxiv.org/abs/2502.18318
      """
  )

  # =====================================================================
+ # 2. Dataset paths (using MOSAIC structure)
  # =====================================================================

+ ds_input = st.sidebar.text_input(
+     "Project/Dataset name", value="MOSAIC", key="dataset_name_input"
+ )
  DATASET_DIR = _slugify(ds_input).upper()

  RAW_DIR = raw_path(DATASET_DIR)

  CACHE_DIR.mkdir(parents=True, exist_ok=True)
  EVAL_DIR.mkdir(parents=True, exist_ok=True)

+ with st.sidebar.expander("About the dataset name", expanded=False):
+     st.markdown(
+         f"""
+ - The name above is converted to **UPPER CASE** and used as a folder name.
+ - If the folder doesn't exist, it will be **created**:
+   - Preprocessed CSVs: `{PROC_DIR}`
+   - Exports (results): `{EVAL_DIR}`
+ - If you choose **Use preprocessed CSV on server**, I'll list CSVs in `{PROC_DIR}`.
+ - If you **upload** a CSV, it will be saved to `{PROC_DIR}/uploaded.csv`.
+ """.strip()
+     )
+
+
  def _list_server_csvs(proc_dir: Path) -> list[str]:
      return [str(p) for p in sorted(proc_dir.glob("*.csv"))]

+
+ DATASETS = None  # keep name for clarity; we'll fill it when rendering the sidebar
  HISTORY_FILE = str(PROC_DIR / "run_history.json")

  # =====================================================================
  # 3. Embedding & LLM loaders
  # =====================================================================

+
  @st.cache_resource
  def load_embedding_model(model_name):
      st.info(f"Loading embedding model '{model_name}'...")
      return SentenceTransformer(model_name)

+ # --- Added for Pro Version ---
  @st.cache_resource
  def load_llm_model():
+     """Loads LlamaCPP quantised model for topic labeling."""
+     st.info("Loading Llama-3-8B-Instruct (Quantized)... This may take a moment.")
+     model_repo = "NousResearch/Meta-Llama-3-8B-Instruct-GGUF"
+     model_file = "Meta-Llama-3-8B-Instruct-Q4_K_M.gguf"
      try:
          model_path = hf_hub_download(repo_id=model_repo, filename=model_file)
          return Llama(model_path=model_path, n_gpu_layers=-1, n_ctx=8192, stop=["Q:", "\n"], verbose=False)
      except Exception as e:
          st.error(f"Failed to load LLM: {e}")
          return None
+ # -----------------------------

  @st.cache_data
  def load_precomputed_data(docs_file, embeddings_file):

      emb = np.load(embeddings_file, allow_pickle=True)
      return docs, emb

+
  # =====================================================================
  # 4. Topic modeling function
  # =====================================================================

+
  def get_config_hash(cfg):
      return json.dumps(cfg, sort_keys=True)

+
  @st.cache_data
+ def perform_topic_modeling(_docs, _embeddings, config_hash):
      """Fit BERTopic using cached result."""
+
      _docs = list(_docs)
      _embeddings = np.asarray(_embeddings)
      if _embeddings.dtype == object or _embeddings.ndim != 2:
          try:
              _embeddings = np.vstack(_embeddings)
          except Exception:
+             st.error(
+                 f"Embeddings are invalid (dtype={_embeddings.dtype}, ndim={_embeddings.ndim}). "
+                 "Please click **Prepare Data** to regenerate."
+             )
              st.stop()
      _embeddings = np.ascontiguousarray(_embeddings, dtype=np.float32)

      if _embeddings.shape[0] != len(_docs):
+         st.error(
+             f"Mismatch between docs and embeddings: len(docs)={len(_docs)} vs "
+             f"embeddings.shape[0]={_embeddings.shape[0]}. "
+             "Delete the cached files for this configuration and regenerate."
+         )
          st.stop()

      config = json.loads(config_hash)

      if "ngram_range" in config["vectorizer_params"]:
+         config["vectorizer_params"]["ngram_range"] = tuple(
+             config["vectorizer_params"]["ngram_range"]
+         )

+     # --- LLM Representation Setup (Added for Pro Version) ---
+     llm = load_llm_model()
      rep_model = None
+
+     if llm:
+         prompt = """Q:
  You are an expert in micro-phenomenology. The following documents are reflections from participants about their experience.
  I have a topic that contains the following documents:
  [DOCUMENTS]
  The topic is described by the following keywords: '[KEYWORDS]'.
  Based on the above information, give a short, informative label (5–10 words).
  A:"""
+         rep_model = {
+             "LLM": LlamaCPP(llm, prompt=prompt, nr_docs=25, doc_length=300, tokenizer="whitespace")
+         }
+     # -----------------------------------------------------

      umap_model = UMAP(random_state=42, metric="cosine", **config["umap_params"])
+     hdbscan_model = HDBSCAN(
+         metric="euclidean", prediction_data=True, **config["hdbscan_params"]
+     )
+     vectorizer_model = (
+         CountVectorizer(**config["vectorizer_params"])
+         if config["use_vectorizer"]
+         else None
+     )
+
+     nr_topics_val = (
+         None
+         if config["bt_params"]["nr_topics"] == "auto"
+         else int(config["bt_params"]["nr_topics"])
+     )

      topic_model = BERTopic(
          umap_model=umap_model,
          hdbscan_model=hdbscan_model,
          vectorizer_model=vectorizer_model,
+         representation_model=rep_model,  # <-- Pass LLM representation here
          top_n_words=config["bt_params"]["top_n_words"],
          nr_topics=nr_topics_val,
          verbose=False,

      outlier_pct = 0
      if -1 in info.Topic.values:
+         outlier_pct = (
+             info.Count[info.Topic == -1].iloc[0] / info.Count.sum()
+         ) * 100

+     # --- Extract Labels (Prefer LLM if available) ---
+     if rep_model and "LLM" in topic_model.get_topics(full=True):
          raw_labels = [label[0][0] for label in topic_model.get_topics(full=True)["LLM"].values()]
          cleaned_labels = [lbl.split(":")[-1].strip().strip('"').strip(".") for lbl in raw_labels]
          final_labels = [lbl if lbl else "Unlabelled" for lbl in cleaned_labels]
          all_labels = [final_labels[topic + topic_model._outliers] if topic != -1 else "Unlabelled" for topic in topics]
      else:
+         # Fallback for when LLM fails or is not present
+         topic_info = topic_model.get_topic_info()
+         name_map = topic_info.set_index("Topic")["Name"].to_dict()
          all_labels = [name_map[topic] for topic in topics]
+     # -----------------------------------------------

+     reduced = UMAP(
+         n_neighbors=15,
+         n_components=2,
+         min_dist=0.0,
+         metric="cosine",
+         random_state=42,
+     ).fit_transform(_embeddings)

      return topic_model, reduced, all_labels, len(info) - 1, outlier_pct
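The 2-D projection used for plotting is a second UMAP, independent of the one used for clustering; a self-contained sketch with stand-in embeddings:

    import numpy as np
    from umap import UMAP

    emb = np.random.rand(200, 384).astype(np.float32)    # stand-in embeddings
    xy = UMAP(n_neighbors=15, n_components=2, min_dist=0.0,
              metric="cosine", random_state=42).fit_transform(emb)
    assert xy.shape == (200, 2)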
 
+
  # =====================================================================
  # 5. CSV → documents → embeddings pipeline
  # =====================================================================

+
+ def generate_and_save_embeddings(
+     csv_path,
+     docs_file,
+     emb_file,
+     selected_embedding_model,
+     split_sentences,
+     device,
+     text_col=None,
+ ):
+
+     # ---------------------
+     # Load & clean CSV
+     # ---------------------
      st.info(f"Reading and preparing CSV: {csv_path}")
      df = pd.read_csv(csv_path)

+     if text_col is not None and text_col in df.columns:
+         col = text_col
+     else:
+         col = _pick_text_column(df)
+
      if col is None:
          st.error("CSV must contain at least one text column.")
          return

      df = df[df["reflection_answer_english"].str.strip() != ""]
      reports = df["reflection_answer_english"].tolist()

+     # ---------------------
+     # Sentence / report granularity
+     # ---------------------
      if split_sentences:
          try:
              sentences = [s for r in reports for s in nltk.sent_tokenize(r)]

      np.save(docs_file, np.array(docs, dtype=object))
      st.success(f"Prepared {len(docs)} documents")

+     # ---------------------
+     # Embeddings
+     # ---------------------
+     st.info(
+         f"Encoding {len(docs)} documents with {selected_embedding_model} on {device}"
+     )

+     model = load_embedding_model(selected_embedding_model)

+     encode_device = None
+     batch_size = 32
+     if device == "CPU":
+         encode_device = "cpu"
+         batch_size = 64
+
+     embeddings = model.encode(
+         docs,
+         show_progress_bar=True,
+         batch_size=batch_size,
+         device=encode_device,
+         convert_to_numpy=True,
+     )
      embeddings = np.asarray(embeddings, dtype=np.float32)
      np.save(emb_file, embeddings)

      st.balloons()
      st.rerun()

+
  # =====================================================================
  # 6. Sidebar — dataset, upload, parameters
  # =====================================================================

  st.sidebar.header("Data Input Method")

+ source = st.sidebar.radio(
+     "Choose data source",
+     ("Use preprocessed CSV on server", "Upload my own CSV"),
+     index=0,
+     key="data_source",
+ )

  uploaded_csv_path = None
+ CSV_PATH = None  # will be set in the chosen branch

  if source == "Use preprocessed CSV on server":
      available = _list_server_csvs(PROC_DIR)
      if not available:
+         st.info(
+             f"No CSVs found in {PROC_DIR}. Switch to 'Upload my own CSV' or change the dataset name."
+         )
          st.stop()
+     selected_csv = st.sidebar.selectbox(
+         "Choose a preprocessed CSV", available, key="server_csv_select"
+     )
      CSV_PATH = selected_csv
  else:
+     up = st.sidebar.file_uploader(
+         "Upload a CSV", type=["csv"], key="upload_csv"
+     )
+
+     st.sidebar.caption(
+         "Your CSV should have **one row per report** and at least one text column "
+         "(for example `reflection_answer_english`, `reflection_answer`, `text`, `report`, "
+         "or any other column containing free text). "
+         "Other columns (ID, condition, etc.) are allowed. "
+         "After upload, you'll be able to choose which text column to analyse."
+     )
+
+
      if up is not None:
+         # List of encodings to try:
+         # 1. utf-8 (Standard)
+         # 2. mac_roman (Fixes the Õ and É issues from Mac Excel)
+         # 3. cp1252 (Standard Windows Excel)
          encodings_to_try = ['utf-8', 'mac_roman', 'cp1252', 'ISO-8859-1']
+
          tmp_df = None
+         success_encoding = None
+
          for encoding in encodings_to_try:
              try:
+                 up.seek(0)  # Always reset to start of file before trying
                  tmp_df = pd.read_csv(up, encoding=encoding)
+                 success_encoding = encoding
+                 break  # If we get here, it worked, so stop the loop
              except UnicodeDecodeError:
+                 continue  # If it fails, try the next one
+
+         if tmp_df is None:
+             st.error("Could not decode file. Please save your CSV as 'CSV UTF-8' in Excel.")
              st.stop()
+
+         if tmp_df.empty:
+             st.error("Uploaded CSV is empty.")
+             st.stop()
+
+         # Optional: Print which encoding worked to the logs (for your info)
+         print(f"Successfully loaded CSV using {success_encoding} encoding.")
+
+         # FIX: Use the original filename to avoid cache collisions
+         # We sanitize the name to be safe for file systems
          safe_filename = _slugify(os.path.splitext(up.name)[0])
          _cleanup_old_cache(safe_filename)
          uploaded_csv_path = str((PROC_DIR / f"{safe_filename}.csv").resolve())
+
          tmp_df.to_csv(uploaded_csv_path, index=False)
          st.success(f"Uploaded CSV saved to {uploaded_csv_path}")
          CSV_PATH = uploaded_csv_path

  if CSV_PATH is None:
      st.stop()

+ # ---------------------------------------------------------------------
  # Text column selection
+ # ---------------------------------------------------------------------
+
+
  @st.cache_data
  def get_text_columns(csv_path: str) -> list[str]:
      df_sample = pd.read_csv(csv_path, nrows=2000)
      return _list_text_columns(df_sample)

  text_columns = get_text_columns(CSV_PATH)
+
  if not text_columns:
+     st.error(
+         "No columns found in this CSV. At least one column is required."
+     )
      st.stop()

+
+ text_columns = get_text_columns(CSV_PATH)
+ if not text_columns:
+     st.error(
+         "No text-like columns found in this CSV. At least one column must contain text."
+     )
+     st.stop()
+
+ # Try to pick a nice default (one of the MOSAIC-ish names) if present
  try:
      df_sample = pd.read_csv(CSV_PATH, nrows=2000)
      preferred = _pick_text_column(df_sample)
+ except Exception:
+     preferred = None
+
+ if preferred in text_columns:
+     default_idx = text_columns.index(preferred)
+ else:
+     default_idx = 0

+ selected_text_column = st.sidebar.selectbox(
+     "Text column to analyse",
+     text_columns,
+     index=default_idx,
+     key="text_column_select",
+ )
+
+ # ---------------------------------------------------------------------
+ # Data granularity & subsampling
+ # ---------------------------------------------------------------------

  st.sidebar.subheader("Data Granularity & Subsampling")
+
+ selected_granularity = st.sidebar.checkbox(
+     "Split reports into sentences", value=True
+ )
  granularity_label = "sentences" if selected_granularity else "reports"
+
  subsample_perc = st.sidebar.slider("Data sampling (%)", 10, 100, 100, 5)

  st.sidebar.markdown("---")

+ # ---------------------------------------------------------------------
+ # Embedding model & device
+ # ---------------------------------------------------------------------
+
  st.sidebar.header("Model Selection")
+
+ selected_embedding_model = st.sidebar.selectbox(
+     "Choose an embedding model",
+     (
+         "BAAI/bge-small-en-v1.5",
+         "intfloat/multilingual-e5-large-instruct",
+         "Qwen/Qwen3-Embedding-0.6B",
+         "sentence-transformers/all-mpnet-base-v2",
+     ),
+ )
+
+ selected_device = st.sidebar.radio(
+     "Processing device",
+     ["GPU (MPS)", "CPU"],
+     index=0,
+ )

  # =====================================================================
+ # 7. Precompute filenames and pipeline triggers
  # =====================================================================

+
  def get_precomputed_filenames(csv_path, model_name, split_sentences, text_col):
      base = os.path.splitext(os.path.basename(csv_path))[0]
      safe_model = re.sub(r"[^a-zA-Z0-9_-]", "_", model_name)
      suf = "sentences" if split_sentences else "reports"
+
+     col_suffix = ""
+     if text_col:
+         safe_col = re.sub(r"[^a-zA-Z0-9_-]", "_", text_col)
+         col_suffix = f"_{safe_col}"
+
      return (
          str(CACHE_DIR / f"precomputed_{base}{col_suffix}_{suf}_docs.npy"),
+         str(
+             CACHE_DIR
+             / f"precomputed_{base}_{safe_model}{col_suffix}_{suf}_embeddings.npy"
+         ),
      )

+ DOCS_FILE, EMBEDDINGS_FILE = get_precomputed_filenames(
+     CSV_PATH, selected_embedding_model, selected_granularity, selected_text_column
+ )
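A worked example of the cache-naming scheme above, with hypothetical inputs:

    import re
    from pathlib import Path

    CACHE_DIR = Path("cache")                             # stand-in location
    base, suf = "reports", "sentences"
    safe_model = re.sub(r"[^a-zA-Z0-9_-]", "_", "BAAI/bge-small-en-v1.5")
    col_suffix = "_" + re.sub(r"[^a-zA-Z0-9_-]", "_", "text")
    print(CACHE_DIR / f"precomputed_{base}{col_suffix}_{suf}_docs.npy")
    # cache/precomputed_reports_text_sentences_docs.npy
    print(CACHE_DIR / f"precomputed_{base}_{safe_model}{col_suffix}_{suf}_embeddings.npy")
    # cache/precomputed_reports_BAAI_bge-small-en-v1_5_text_sentences_embeddings.npy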
+ # --- Cache management ---
  st.sidebar.markdown("### Cache")
+ if st.sidebar.button(
+     "Clear cached files for this configuration", use_container_width=True
+ ):
+     try:
+         for p in (DOCS_FILE, EMBEDDINGS_FILE):
+             if os.path.exists(p):
+                 os.remove(p)
+         try:
+             load_precomputed_data.clear()
+         except Exception:
+             pass
+         try:
+             perform_topic_modeling.clear()
+         except Exception:
+             pass
+
+         st.success(
+             "Deleted cached docs/embeddings and cleared caches. Click **Prepare Data** again."
+         )
+         st.rerun()
+     except Exception as e:
+         st.error(f"Failed to delete cache files: {e}")

  st.sidebar.markdown("---")

  # =====================================================================
+ # 8. Prepare Data OR Run Analysis
  # =====================================================================

  if not os.path.exists(EMBEDDINGS_FILE):
+     st.warning(
+         f"No precomputed embeddings found for this configuration "
+         f"({granularity_label} / {selected_embedding_model} / column '{selected_text_column}')."
+     )
+
      if st.button("Prepare Data for This Configuration"):
+         generate_and_save_embeddings(
+             CSV_PATH,
+             DOCS_FILE,
+             EMBEDDINGS_FILE,
+             selected_embedding_model,
+             selected_granularity,
+             selected_device,
+             text_col=selected_text_column,
+         )
+
  else:
+     # Load cached data
      docs, embeddings = load_precomputed_data(DOCS_FILE, EMBEDDINGS_FILE)
+
      embeddings = np.asarray(embeddings)
      if embeddings.dtype == object or embeddings.ndim != 2:
          try:
              embeddings = np.vstack(embeddings).astype(np.float32)
          except Exception:
+             st.error(
+                 "Cached embeddings are invalid. Please regenerate them for this configuration."
+             )
              st.stop()

      if subsample_perc < 100:
          n = int(len(docs) * (subsample_perc / 100))
          idx = np.random.choice(len(docs), size=n, replace=False)
          docs = [docs[i] for i in idx]
+         embeddings = np.asarray(embeddings)[idx, :]
+         st.warning(
+             f"Running analysis on {subsample_perc}% subsample ({len(docs)} documents)"
+         )

+     # Dataset summary
      st.subheader("Dataset summary")
      n_reports = count_clean_reports(CSV_PATH, selected_text_column)
+     unit = "sentences" if selected_granularity else "reports"
+     n_units = len(docs)

+     c1, c2 = st.columns(2)
+     c1.metric("Reports in CSV (cleaned)", n_reports)
+     c2.metric(f"Units analysed ({unit})", n_units)
+
+     # --- Parameter controls ---
      st.sidebar.header("Model Parameters")
+
      use_vectorizer = st.sidebar.checkbox("Use CountVectorizer", value=True)
+
      with st.sidebar.expander("Vectorizer"):
+         ng_min = st.slider("Min N-gram", 1, 5, 1)
+         ng_max = st.slider("Max N-gram", 1, 5, 2)
          min_df = st.slider("Min Doc Freq", 1, 50, 1)
+         stopwords = st.select_slider(
+             "Stopwords", options=[None, "english"], value=None
+         )
+
      with st.sidebar.expander("UMAP"):
          um_n = st.slider("n_neighbors", 2, 50, 15)
          um_c = st.slider("n_components", 2, 20, 5)
          um_d = st.slider("min_dist", 0.0, 1.0, 0.0)
+
      with st.sidebar.expander("HDBSCAN"):
          hs = st.slider("min_cluster_size", 5, 100, 10)
          hm = st.slider("min_samples", 2, 100, 5)
+
      with st.sidebar.expander("BERTopic"):
          nr_topics = st.text_input("nr_topics", value="auto")
          top_n_words = st.slider("top_n_words", 5, 25, 10)

          "granularity": granularity_label,
          "subsample_percent": subsample_perc,
          "use_vectorizer": use_vectorizer,
+         "vectorizer_params": {
+             "ngram_range": (ng_min, ng_max),
+             "min_df": min_df,
+             "stop_words": stopwords,
+         },
+         "umap_params": {
+             "n_neighbors": um_n,
+             "n_components": um_c,
+             "min_dist": um_d,
+         },
+         "hdbscan_params": {
+             "min_cluster_size": hs,
+             "min_samples": hm,
+         },
+         "bt_params": {
+             "nr_topics": nr_topics,
+             "top_n_words": top_n_words,
+         },
          "text_column": selected_text_column,
      }

      run_button = st.sidebar.button("Run Analysis", type="primary")

      def load_history():
          path = HISTORY_FILE
+         if not os.path.exists(path):
+             return []
          try:
              data = json.load(open(path))
+         except Exception:
+             return []
+         for e in data:
+             if "outlier_pct" not in e and "outlier_perc" in e:
+                 e["outlier_pct"] = e.pop("outlier_perc")
+         return data

      def save_history(h):
          json.dump(h, open(HISTORY_FILE, "w"), indent=2)
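The legacy-key migration inside `load_history` in isolation:

    entries = [{"outlier_perc": "3.10%"}, {"outlier_pct": "2.00%"}]
    for e in entries:
        if "outlier_pct" not in e and "outlier_perc" in e:
            e["outlier_pct"] = e.pop("outlier_perc")
    assert all("outlier_pct" in e for e in entries)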
 
      st.session_state.history = load_history()

      if run_button:
+         if not isinstance(embeddings, np.ndarray):
+             embeddings = np.asarray(embeddings)
+
+         if embeddings.dtype == object or embeddings.ndim != 2:
+             try:
+                 embeddings = np.vstack(embeddings).astype(np.float32)
+             except Exception:
+                 st.error(
+                     "Cached embeddings are invalid (object/ragged). Click **Prepare Data** to regenerate."
+                 )
+                 st.stop()
+
+         if embeddings.shape[0] != len(docs):
+             st.error(
+                 f"len(docs)={len(docs)} but embeddings.shape[0]={embeddings.shape[0]}.\n"
+                 "Likely stale cache (e.g., switched sentences↔reports or model). "
+                 "Use the **Clear cache** button below and regenerate."
+             )
+             st.stop()
+
          with st.spinner("Performing topic modeling..."):
              model, reduced, labels, n_topics, outlier_pct = perform_topic_modeling(
+                 docs, embeddings, get_config_hash(current_config)
              )
          st.session_state.latest_results = (model, reduced, labels)

              "config": current_config,
              "num_topics": n_topics,
              "outlier_pct": f"{outlier_pct:.2f}%",
+             "llm_labels": [
+                 name
+                 for name in model.get_topic_info().Name.values
+                 if ("Unlabelled" not in name and "outlier" not in name)
+             ],
          }
          st.session_state.history.insert(0, entry)
          save_history(st.session_state.history)
          st.rerun()

+     # --- MAIN TAB ---
      with main_tab:
          if "latest_results" in st.session_state:
              tm, reduced, labs = st.session_state.latest_results
+
              st.subheader("Experiential Topics Visualisation")
              fig, _ = datamapplot.create_plot(reduced, labs)
              st.pyplot(fig)
+
              st.subheader("Topic Info")
              st.dataframe(tm.get_topic_info())
+
+             st.subheader("Export results (one row per topic)")
+
              full_reps = tm.get_topics(full=True)
              llm_reps = full_reps.get("LLM", {})
+
              llm_names = {}
+             for tid, vals in llm_reps.items():
+                 try:
+                     llm_names[tid] = (
+                         (vals[0][0] or "").strip().strip('"').strip(".")
+                     )
+                 except Exception:
+                     llm_names[tid] = "Unlabelled"
+
+             if not llm_names:
+                 st.caption("Note: Using default keyword-based topic names.")
+                 llm_names = (
+                     tm.get_topic_info().set_index("Topic")["Name"].to_dict()
+                 )

              doc_info = tm.get_document_info(docs)[["Document", "Topic"]]
+
+             include_outliers = st.checkbox(
+                 "Include outlier topic (-1)", value=False
+             )
+             if not include_outliers:
+                 doc_info = doc_info[doc_info["Topic"] != -1]
+
+             grouped = (
+                 doc_info.groupby("Topic")["Document"]
+                 .apply(list)
+                 .reset_index(name="texts")
+             )
+             grouped["topic_name"] = grouped["Topic"].map(llm_names).fillna(
+                 "Unlabelled"
+             )
+
+             export_topics = (
+                 grouped.rename(columns={"Topic": "topic_id"})[
+                     ["topic_id", "topic_name", "texts"]
+                 ]
+                 .sort_values("topic_id")
+                 .reset_index(drop=True)
+             )
+
+             SEP = "\n"
+
+             export_csv = export_topics.copy()
+             export_csv["texts"] = export_csv["texts"].apply(
+                 lambda lst: SEP.join(map(str, lst))
+             )
+
              base = os.path.splitext(os.path.basename(CSV_PATH))[0]
              gran = "sentences" if selected_granularity else "reports"
+             csv_name = f"topics_{base}_{gran}.csv"
+             jsonl_name = f"topics_{base}_{gran}.jsonl"
+             csv_path = (EVAL_DIR / csv_name).resolve()
+             jsonl_path = (EVAL_DIR / jsonl_name).resolve()
+
+             cL, cC, cR = st.columns(3)
+
+             with cL:
+                 if st.button("Save CSV to eval/", use_container_width=True):
+                     try:
+                         export_csv.to_csv(csv_path, index=False)
+                         st.success(f"Saved CSV → {csv_path}")
+                     except Exception as e:
+                         st.error(f"Failed to save CSV: {e}")
+
+             with cC:
+                 if st.button("Save JSONL to eval/", use_container_width=True):
+                     try:
+                         with open(jsonl_path, "w", encoding="utf-8") as f:
+                             for _, row in export_topics.iterrows():
+                                 rec = {
+                                     "topic_id": int(row["topic_id"]),
+                                     "topic_name": row["topic_name"],
+                                     "texts": list(map(str, row["texts"])),
+                                 }
+                                 f.write(
+                                     json.dumps(rec, ensure_ascii=False) + "\n"
+                                 )
+                         st.success(f"Saved JSONL → {jsonl_path}")
+                     except Exception as e:
+                         st.error(f"Failed to save JSONL: {e}")
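Reading the per-topic JSONL export back is one `json.loads` per line (path hypothetical):

    import json

    with open("eval/topics_reports_sentences.jsonl", encoding="utf-8") as f:
        topics = [json.loads(line) for line in f]
    for t in topics:
        print(t["topic_id"], t["topic_name"], f"{len(t['texts'])} texts")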
+             with cR:
+                 # Create a Long Format DataFrame (One row per sentence)
+                 # This ensures NO text is hidden due to Excel cell limits
+                 long_format_df = doc_info.copy()
+                 long_format_df["Topic Name"] = long_format_df["Topic"].map(llm_names).fillna("Unlabelled")
+
+                 # Reorder columns for clarity
+                 long_format_df = long_format_df[["Topic", "Topic Name", "Document"]]
+
+                 # Define filename
+                 long_csv_name = f"all_sentences_{base}_{gran}.csv"
+
+                 st.download_button(
+                     "Download All Sentences (Long Format)",
+                     data=long_format_df.to_csv(index=False).encode("utf-8-sig"),
+                     file_name=long_csv_name,
+                     mime="text/csv",
+                     use_container_width=True,
+                     help="Download a CSV with one row per sentence. Best for checking exactly which sentences belong to which topic."
+                 )
+
+             # st.caption("Preview (one row per topic)")
+             st.dataframe(export_csv)
+
+         else:
+             st.info("Click 'Run Analysis' to begin.")

+     # --- HISTORY TAB ---
      with history_tab:
          st.subheader("Run History")
          if not st.session_state.history:

          for i, entry in enumerate(st.session_state.history):
              with st.expander(f"Run {i+1} — {entry['timestamp']}"):
                  st.write(f"**Topics:** {entry['num_topics']}")
+                 st.write(
+                     f"**Outliers:** {entry.get('outlier_pct', entry.get('outlier_perc', 'N/A'))}"
+                 )
                  st.write("**Topic Labels:**")
+                 st.write(entry["llm_labels"])
+                 with st.expander("Show full configuration"):
+                     st.json(entry["config"])