|
|
|
|
|
""" |
|
|
Produce answer embeddings for the dataset using a fine-tuned MuRIL model. |
|
|
Saves: |
|
|
- muril_multilingual_dataset.csv (columns: question, answer, language) |
|
|
- answer_embeddings.pt (torch tensor shape [N, D], float32, on CPU) |
|
|
|
|
|
Usage: |
|
|
python embed_build_muril.py \ |
|
|
--model_dir ./muril_multilang_out \ |
|
|
--input_jsonl /path/to/legal_multilingual_QA_10k.jsonl \ |
|
|
--out_dir ./export_artifacts \ |
|
|
--batch_size 64 |
|
|
""" |
|
|
import argparse, os, math |
|
|
from pathlib import Path |
|
|
|
|
|
import torch |
|
|
import pandas as pd |
|
|
from tqdm.auto import tqdm |
|
|
from transformers import AutoTokenizer, AutoModel |
|
|
|
|
|
def parse_args(argv=None):
    """Parse command-line options for the embedding build script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            falls back to ``sys.argv[1:]``, preserving the original CLI
            behavior; passing an explicit list makes the parser callable
            (and testable) programmatically.

    Returns:
        argparse.Namespace with model_dir, input_jsonl, out_dir, langs,
        text_prefix, answer_col_prefix, batch_size, and device.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--model_dir", type=str, default="./muril_multilang_out", help="Path or HF repo id of fine-tuned MuRIL")
    p.add_argument("--input_jsonl", type=str, required=True, help="Path to legal_multilingual_QA_10k.jsonl")
    p.add_argument("--out_dir", type=str, default="./export_artifacts")
    p.add_argument("--langs", type=str, default="en,hi,mr,ta,bn,gu,kn,ml,pa,or,as,ur,sa,ne", help="comma-separated languages to merge (will stack)")
    p.add_argument("--text_prefix", type=str, default="question_", help="prefix for question columns in JSONL")
    p.add_argument("--answer_col_prefix", type=str, default="answer_", help="prefix for answer columns if present (not used here)")
    p.add_argument("--batch_size", type=int, default=64)
    # Device defaults to CUDA when available; overridable via --device.
    p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    return p.parse_args(argv)
|
|
|
|
|
def mean_pooling(last_hidden_state, attention_mask):
    """Average token embeddings over the sequence, ignoring padding.

    Args:
        last_hidden_state: float tensor of shape [batch, seq_len, dim].
        attention_mask: int/bool tensor of shape [batch, seq_len]; 1 marks
            real tokens, 0 marks padding.

    Returns:
        Tensor of shape [batch, dim]: per-example mean of the unmasked
        token vectors.
    """
    # Broadcast the mask over the hidden dimension so padded positions
    # contribute zero to the sum.
    mask = attention_mask.unsqueeze(-1).float()
    masked_sum = (last_hidden_state * mask).sum(dim=1)
    # Clamp avoids division by zero for all-padding rows.
    token_counts = mask.sum(dim=1).clamp(min=1e-9)
    return masked_sum / token_counts
|
|
|
|
|
def build_question_answer_rows(df, langs, text_prefix):
    """Stack per-language question columns into long-format rows.

    For each input row and each language, emits one output row when both a
    non-empty question (column ``{text_prefix}{lang}``) and a non-empty
    answer are present. The answer comes from ``answer_{lang}`` when that
    column exists and holds a value, otherwise from a generic ``answer``
    column.

    Args:
        df: wide-format DataFrame with per-language columns.
        langs: iterable of language codes to extract.
        text_prefix: prefix of the question columns (e.g. "question_").

    Returns:
        DataFrame with columns: question, answer, language.
    """
    records = []
    # Column membership doesn't change per row; hoist it out of the loop.
    available = set(df.columns)
    for _, row in df.iterrows():
        for lang in langs:
            question = row.get(f"{text_prefix}{lang}", None)
            if question is None:
                continue
            q_text = str(question)
            # Skip blanks and the string form of pandas NaN.
            if q_text.strip() == "" or q_text.lower() == "nan":
                continue

            answer_col = f"answer_{lang}"
            if answer_col in available and pd.notna(row.get(answer_col)):
                answer = row.get(answer_col)
            else:
                # Fall back to a language-agnostic "answer" column.
                answer = row.get("answer", None)
            if answer is None:
                continue
            a_text = str(answer)
            if a_text.strip() == "" or a_text.lower() == "nan":
                continue

            records.append({
                "question": q_text.strip(),
                "answer": a_text.strip(),
                "language": lang,
            })
    return pd.DataFrame(records)
|
|
|
|
|
def main():
    """CLI entry point: merge language columns, embed answers, save artifacts.

    Writes two files into ``--out_dir``:
      - muril_multilingual_dataset.csv (question, answer, language)
      - answer_embeddings.pt (L2-normalized float tensor [N, D] on CPU)
    """
    args = parse_args()
    os.makedirs(args.out_dir, exist_ok=True)

    # --- Load and reshape the dataset -----------------------------------
    print("Loading dataset:", args.input_jsonl)
    source_df = pd.read_json(args.input_jsonl, lines=True, dtype=str)

    languages = [code.strip() for code in args.langs.split(",") if code.strip()]
    print("Merging language columns (stack)... langs:", languages)
    merged = build_question_answer_rows(source_df, languages, args.text_prefix)
    if merged.empty:
        raise SystemExit("No question/answer rows found after merging languages. Check your columns.")
    print(f"Total rows extracted: {len(merged)}")

    csv_path = Path(args.out_dir) / "muril_multilingual_dataset.csv"
    merged.to_csv(csv_path, index=False, encoding="utf-8")
    print("Saved merged CSV to:", csv_path)

    # --- Load the fine-tuned encoder ------------------------------------
    print("Loading tokenizer & model from:", args.model_dir, "device:", args.device)
    tokenizer = AutoTokenizer.from_pretrained(args.model_dir, use_fast=True)
    model = AutoModel.from_pretrained(args.model_dir)
    model.to(args.device)
    model.eval()

    # --- Batched answer encoding ----------------------------------------
    answer_texts = merged["answer"].astype(str).tolist()
    step = int(args.batch_size)
    chunks = []
    with torch.inference_mode():
        for start in tqdm(range(0, len(answer_texts), step), desc="Encoding"):
            batch_texts = answer_texts[start:start + step]
            encoded = tokenizer(
                batch_texts,
                padding=True,
                truncation=True,
                max_length=256,
                return_tensors="pt",
            )
            ids = encoded["input_ids"].to(args.device)
            mask = encoded["attention_mask"].to(args.device)
            outputs = model(input_ids=ids, attention_mask=mask, return_dict=True)
            pooled = mean_pooling(outputs.last_hidden_state, mask)
            # L2-normalize so dot products equal cosine similarity downstream.
            pooled = torch.nn.functional.normalize(pooled, p=2, dim=1)
            chunks.append(pooled.cpu())

    embeddings = torch.cat(chunks, dim=0)
    print("Embeddings shape:", embeddings.shape)
    embed_path = Path(args.out_dir) / "answer_embeddings.pt"
    torch.save(embeddings, embed_path)
    print("Saved embeddings to:", embed_path)

    print("Done. Artifacts in:", args.out_dir)
|
|
|
|
|
# Run only when invoked as a script; importing this module stays side-effect free.
if __name__ == "__main__":


    main()
|
|
|