File size: 5,426 Bytes
ce552a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# embed_build_muril.py
"""
Produce answer embeddings for the dataset using a fine-tuned MuRIL model.
Saves:
 - muril_multilingual_dataset.csv  (columns: question, answer, language)
 - answer_embeddings.pt            (torch tensor shape [N, D], float32, on CPU)

Usage:
 python embed_build_muril.py \
    --model_dir ./muril_multilang_out \
    --input_jsonl /path/to/legal_multilingual_QA_10k.jsonl \
    --out_dir ./export_artifacts \
    --batch_size 64
"""
import argparse, os, math
from pathlib import Path

import torch
import pandas as pd
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModel

def parse_args():
    """Parse the command-line options of the embedding build script.

    Only ``--input_jsonl`` is mandatory; every other option has a default.
    ``--device`` defaults to ``"cuda"`` when torch reports that CUDA is
    available and to ``"cpu"`` otherwise.

    Returns:
        argparse.Namespace: the options parsed from ``sys.argv``.
    """
    default_device = "cuda" if torch.cuda.is_available() else "cpu"

    # (flag, add_argument keyword arguments), registered in declaration order
    # so that the generated --help output keeps the same option order.
    option_specs = [
        ("--model_dir", dict(type=str, default="./muril_multilang_out", help="Path or HF repo id of fine-tuned MuRIL")),
        ("--input_jsonl", dict(type=str, required=True, help="Path to legal_multilingual_QA_10k.jsonl")),
        ("--out_dir", dict(type=str, default="./export_artifacts")),
        ("--langs", dict(type=str, default="en,hi,mr,ta,bn,gu,kn,ml,pa,or,as,ur,sa,ne", help="comma-separated languages to merge (will stack)")),
        ("--text_prefix", dict(type=str, default="question_", help="prefix for question columns in JSONL")),
        ("--answer_col_prefix", dict(type=str, default="answer_", help="prefix for answer columns if present (not used here)")),
        ("--batch_size", dict(type=int, default=64)),
        ("--device", dict(type=str, default=default_device)),
    ]

    parser = argparse.ArgumentParser()
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser.parse_args()

def mean_pooling(last_hidden_state, attention_mask):
    """Average the token vectors of each sequence, ignoring masked positions.

    Args:
        last_hidden_state: token embeddings, shape (B, L, H).
        attention_mask: mask of shape (B, L); positions with 0 contribute
            nothing to the average.

    Returns:
        Tensor of shape (B, H) holding the mask-weighted mean over the L axis.
    """
    # Broadcast the mask over the hidden dimension so it can weight each token.
    weights = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
    weighted_total = (last_hidden_state * weights).sum(dim=1)
    # Clamp the denominator so a fully masked row divides by 1e-9, not zero.
    token_counts = weights.sum(dim=1).clamp(min=1e-9)
    return weighted_total / token_counts

# Stringified nulls: what None, NaN and pd.NA look like once they have been
# converted with str() (for example after a column was coerced to str).
_MISSING_TOKENS = frozenset({"nan", "none", "<na>"})


def _clean_text(value):
    """Return ``value`` as stripped text, or None when it is missing or blank."""
    if value is None:
        return None
    # Only ask pd.isna about scalars; for list-like cells it returns an array
    # whose truth value is ambiguous.
    if pd.api.types.is_scalar(value) and pd.isna(value):
        return None
    text = str(value).strip()
    if not text or text.lower() in _MISSING_TOKENS:
        return None
    return text


def build_question_answer_rows(df, langs, text_prefix, answer_prefix="answer_"):
    """Stack per-language question/answer columns into one long frame.

    For every input row and every language in ``langs`` (row-major, in the
    given order) the question is read from ``<text_prefix><lang>``. The answer
    is read from ``<answer_prefix><lang>``; when that column is absent or its
    value is missing/blank, the shared ``answer`` column is used instead.
    Pairs whose question or answer is still missing are skipped.

    A value counts as missing when it is None, NaN/NA, empty or
    whitespace-only, or one of the stringified nulls ``nan``/``none``/``<na>``
    (case-insensitive).

    Args:
        df: input frame with one row per item and one column per language.
        langs: iterable of language codes to extract.
        text_prefix: prefix of the question columns.
        answer_prefix: prefix of the per-language answer columns.

    Returns:
        pandas.DataFrame with columns ``question``, ``answer``, ``language``
        (stripped strings). The columns are present even when no row is found.
    """
    available = set(df.columns)

    # Resolve the column names once; a language without a question column can
    # never contribute a row, so it is dropped up front.
    lang_columns = []
    for lang in langs:
        qcol = f"{text_prefix}{lang}"
        if qcol not in available:
            continue
        acol = f"{answer_prefix}{lang}"
        lang_columns.append((lang, qcol, acol if acol in available else None))

    rows = []
    for record in df.to_dict(orient="records"):
        for lang, qcol, acol in lang_columns:
            question = _clean_text(record[qcol])
            if question is None:
                continue
            answer = _clean_text(record[acol]) if acol is not None else None
            if answer is None:
                # Blank language-specific answer: fall back to the shared one.
                answer = _clean_text(record.get("answer"))
            if answer is None:
                continue
            rows.append({"question": question, "answer": answer, "language": lang})
    return pd.DataFrame(rows, columns=["question", "answer", "language"])

def main():
    """Build the merged Q/A CSV and the matching answer-embedding tensor.

    Steps, in order:
      1. Parse the CLI arguments and create the output directory.
      2. Load the JSONL dataset and stack the per-language question/answer
         pairs into one frame (question, answer, language).
      3. Write that frame to ``muril_multilingual_dataset.csv``.
      4. Encode every answer with the model from ``--model_dir`` (mean pooling
         over tokens, then L2 normalisation) and save the concatenated result
         to ``answer_embeddings.pt``.

    Row ``i`` of the CSV corresponds to row ``i`` of the embedding tensor.

    Raises:
        SystemExit: if ``--batch_size`` is not positive, or if no
            question/answer rows could be extracted from the input.
    """
    args = parse_args()

    # Validate before any expensive work: a zero step makes range() raise an
    # opaque ValueError and a negative one yields no batches at all, and both
    # would otherwise only surface after the dataset and model were loaded.
    batch_size = int(args.batch_size)
    if batch_size <= 0:
        raise SystemExit(f"--batch_size must be a positive integer, got {batch_size}.")

    os.makedirs(args.out_dir, exist_ok=True)
    # load JSONL to pandas
    print("Loading dataset:", args.input_jsonl)
    df_in = pd.read_json(args.input_jsonl, lines=True, dtype=str)
    # Build rows stacked across languages (question_<lang>, answer optional)
    langs = [l.strip() for l in args.langs.split(",") if l.strip()]
    print("Merging language columns (stack)... langs:", langs)
    rows_df = build_question_answer_rows(df_in, langs, args.text_prefix)
    if rows_df.empty:
        raise SystemExit("No question/answer rows found after merging languages. Check your columns.")
    print(f"Total rows extracted: {len(rows_df)}")
    # Save CSV (order matters: embeddings below are produced in the same order)
    csv_path = Path(args.out_dir) / "muril_multilingual_dataset.csv"
    rows_df.to_csv(csv_path, index=False, encoding="utf-8")
    print("Saved merged CSV to:", csv_path)

    # Load model & tokenizer
    print("Loading tokenizer & model from:", args.model_dir, "device:", args.device)
    tokenizer = AutoTokenizer.from_pretrained(args.model_dir, use_fast=True)
    model = AutoModel.from_pretrained(args.model_dir)
    model.to(args.device)
    model.eval()

    # Encode answers in batches
    answers = rows_df["answer"].astype(str).tolist()
    all_embs = []
    with torch.inference_mode():
        for i in tqdm(range(0, len(answers), batch_size), desc="Encoding"):
            batch_texts = answers[i:i+batch_size]
            # Answers longer than 256 tokens are truncated by the tokenizer.
            encoded = tokenizer(batch_texts, padding=True, truncation=True, max_length=256, return_tensors="pt")
            input_ids = encoded["input_ids"].to(args.device)
            attention_mask = encoded["attention_mask"].to(args.device)
            out = model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
            last_hidden = out.last_hidden_state  # (B, L, H)
            pooled = mean_pooling(last_hidden, attention_mask)  # (B, H)
            # L2-normalize embeddings (optional but recommended for cosine similarity)
            pooled = torch.nn.functional.normalize(pooled, p=2, dim=1)
            # Move each batch to CPU right away so results do not pile up on the device.
            all_embs.append(pooled.cpu())
    all_embs = torch.cat(all_embs, dim=0)  # (N, H)
    print("Embeddings shape:", all_embs.shape)
    embed_path = Path(args.out_dir) / "answer_embeddings.pt"
    torch.save(all_embs, embed_path)
    print("Saved embeddings to:", embed_path)

    print("Done. Artifacts in:", args.out_dir)

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()