# embed_build_muril.py
"""
Produce answer embeddings for the dataset using a fine-tuned MuRIL model.
Saves:
- muril_multilingual_dataset.csv (columns: question, answer, language)
- answer_embeddings.pt (torch tensor shape [N, D], float32, on CPU)
Usage:
python embed_build_muril.py \
--model_dir ./muril_multilang_out \
--input_jsonl /path/to/legal_multilingual_QA_10k.jsonl \
--out_dir ./export_artifacts \
--batch_size 64
"""
import argparse, os
from pathlib import Path
import torch
import pandas as pd
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModel


def parse_args():
    p = argparse.ArgumentParser()
    p.add_argument("--model_dir", type=str, default="./muril_multilang_out",
                   help="Path or HF repo id of fine-tuned MuRIL")
    p.add_argument("--input_jsonl", type=str, required=True,
                   help="Path to legal_multilingual_QA_10k.jsonl")
    p.add_argument("--out_dir", type=str, default="./export_artifacts")
    p.add_argument("--langs", type=str,
                   default="en,hi,mr,ta,bn,gu,kn,ml,pa,or,as,ur,sa,ne",
                   help="comma-separated languages to merge (will stack)")
    p.add_argument("--text_prefix", type=str, default="question_",
                   help="prefix for question columns in JSONL")
    p.add_argument("--answer_col_prefix", type=str, default="answer_",
                   help="prefix for answer columns if present (not used here)")
    p.add_argument("--batch_size", type=int, default=64)
    p.add_argument("--device", type=str,
                   default="cuda" if torch.cuda.is_available() else "cpu")
    return p.parse_args()


def mean_pooling(last_hidden_state, attention_mask):
    # last_hidden_state: (B, L, H)
    # attention_mask: (B, L)
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
    sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, 1)
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return sum_embeddings / sum_mask
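
# Shape walk-through for mean_pooling (illustrative numbers, assuming a
# hidden size of 768): with last_hidden_state of shape (2, 4, 768) and
# attention_mask of shape (2, 4), the expanded mask is (2, 4, 768), the
# masked sum is (2, 768), and the result is a (2, 768) tensor holding each
# sequence's average over non-padding tokens only.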


def build_question_answer_rows(df, langs, text_prefix):
    rows = []
    for _, r in df.iterrows():
        # Merge all available language question/answer pairs by stacking.
        for lang in langs:
            qcol = f"{text_prefix}{lang}"
            acol = f"answer_{lang}"
            q = r.get(qcol, None)
            if q is None or str(q).strip() == "" or str(q).lower() == "nan":
                continue
            # Use answer_<lang> if the dataset provides per-language answers;
            # otherwise fall back to the shared "answer" column.
            if acol in df.columns and pd.notna(r.get(acol)):
                a = r.get(acol)
            else:
                a = r.get("answer", None)
            if a is None or str(a).strip() == "" or str(a).lower() == "nan":
                continue
            rows.append({"question": str(q).strip(), "answer": str(a).strip(), "language": lang})
    return pd.DataFrame(rows)
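
# Stacking example (illustrative): with langs=["en", "hi"], a record that
# has question_en, question_hi, and a shared "answer" field produces two
# rows, one per language, both reusing the same answer text.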


def main():
    args = parse_args()
    os.makedirs(args.out_dir, exist_ok=True)

    # Load the JSONL into pandas
    print("Loading dataset:", args.input_jsonl)
    df_in = pd.read_json(args.input_jsonl, lines=True, dtype=str)

    # Build rows stacked across languages (question_<lang>; answer optional)
    langs = [l.strip() for l in args.langs.split(",") if l.strip()]
    print("Merging language columns (stack)... langs:", langs)
    rows_df = build_question_answer_rows(df_in, langs, args.text_prefix)
    if rows_df.empty:
        raise SystemExit("No question/answer rows found after merging languages. Check your columns.")
    print(f"Total rows extracted: {len(rows_df)}")

    # Save the CSV (row order matters: row i corresponds to embedding i)
    csv_path = Path(args.out_dir) / "muril_multilingual_dataset.csv"
    rows_df.to_csv(csv_path, index=False, encoding="utf-8")
    print("Saved merged CSV to:", csv_path)

    # Load model & tokenizer
    print("Loading tokenizer & model from:", args.model_dir, "device:", args.device)
    tokenizer = AutoTokenizer.from_pretrained(args.model_dir, use_fast=True)
    model = AutoModel.from_pretrained(args.model_dir)
    model.to(args.device)
    model.eval()

    # Encode answers in batches
    answers = rows_df["answer"].astype(str).tolist()
    batch_size = int(args.batch_size)
    all_embs = []
    with torch.inference_mode():
        for i in tqdm(range(0, len(answers), batch_size), desc="Encoding"):
            batch_texts = answers[i:i + batch_size]
            encoded = tokenizer(batch_texts, padding=True, truncation=True,
                                max_length=256, return_tensors="pt")
            input_ids = encoded["input_ids"].to(args.device)
            attention_mask = encoded["attention_mask"].to(args.device)
            out = model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
            last_hidden = out.last_hidden_state  # (B, L, H)
            pooled = mean_pooling(last_hidden, attention_mask)  # (B, H)
            # L2-normalize embeddings (optional, but recommended for cosine similarity)
            pooled = torch.nn.functional.normalize(pooled, p=2, dim=1)
            all_embs.append(pooled.cpu())

    all_embs = torch.cat(all_embs, dim=0)  # (N, H)
    print("Embeddings shape:", all_embs.shape)
    embed_path = Path(args.out_dir) / "answer_embeddings.pt"
    torch.save(all_embs, embed_path)
    print("Saved embeddings to:", embed_path)
    print("Done. Artifacts in:", args.out_dir)


if __name__ == "__main__":
    main()