Sai809701 committed on
Commit
ce552a1
·
1 Parent(s): 3b2edfc
Files changed (4) hide show
  1. Dockerfile +25 -0
  2. embed_build.py +114 -0
  3. main.py +221 -0
  4. requirements.txt +7 -0
Dockerfile ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Dockerfile.muril
# Image for the MuRIL multilingual QA API (the FastAPI app in main.py, served by uvicorn).
FROM python:3.11-slim

# Keep apt from prompting during the image build.
ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential git curl ca-certificates \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /app

# Install Python dependencies before copying the app so this layer is reused
# when only main.py changes. The dependency list shipped alongside this
# Dockerfile is requirements.txt.
COPY requirements.txt /app/requirements.txt
RUN pip install --no-cache-dir -r /app/requirements.txt

# Copy app
COPY main.py /app/main.py

ENV HF_HOME=/app/hf_cache
ENV TRANSFORMERS_CACHE=/app/hf_cache
# main.py reads TORCH_DISABLE_CUDA and selects the CPU device when it is "1".
ENV TORCH_DISABLE_CUDA=1
ENV OUT_DIR=/app/export_artifacts
ENV MODEL_DIR=/app/muril_multilang_out
# main.py reads PORT only when started as a script; the CMD below pins 7860.
ENV PORT=7860

EXPOSE 7860
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
embed_build.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # embed_build.py
2
+ """
3
+ Produce answer embeddings for the dataset using a fine-tuned MuRIL model.
4
+ Saves:
5
+ - muril_multilingual_dataset.csv (columns: question, answer, language)
6
+ - answer_embeddings.pt (torch tensor shape [N, D], float32, on CPU)
7
+
8
+ Usage:
9
+ python embed_build.py \
10
+ --model_dir ./muril_multilang_out \
11
+ --input_jsonl /path/to/legal_multilingual_QA_10k.jsonl \
12
+ --out_dir ./export_artifacts \
13
+ --batch_size 64
14
+ """
15
+ import argparse, os, math
16
+ from pathlib import Path
17
+
18
+ import torch
19
+ import pandas as pd
20
+ from tqdm.auto import tqdm
21
+ from transformers import AutoTokenizer, AutoModel
22
+
23
def _positive_int(value):
    """argparse ``type=`` callable: parse *value* as an integer >= 1."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"expected an integer, got {value!r}") from None
    if number < 1:
        raise argparse.ArgumentTypeError(f"must be >= 1, got {number}")
    return number


def parse_args(argv=None):
    """Parse the command-line options of the embedding build script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is what the script entry
            point relies on.

    Returns:
        argparse.Namespace with ``model_dir``, ``input_jsonl``, ``out_dir``,
        ``langs``, ``text_prefix``, ``answer_col_prefix``, ``batch_size`` and
        ``device``.

    Raises:
        SystemExit: argparse exits with status 2 when ``--input_jsonl`` is
            missing or ``--batch_size`` is not a positive integer.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--model_dir", type=str, default="./muril_multilang_out", help="Path or HF repo id of fine-tuned MuRIL")
    p.add_argument("--input_jsonl", type=str, required=True, help="Path to legal_multilingual_QA_10k.jsonl")
    p.add_argument("--out_dir", type=str, default="./export_artifacts")
    p.add_argument("--langs", type=str, default="en,hi,mr,ta,bn,gu,kn,ml,pa,or,as,ur,sa,ne", help="comma-separated languages to merge (will stack)")
    p.add_argument("--text_prefix", type=str, default="question_", help="prefix for question columns in JSONL")
    p.add_argument("--answer_col_prefix", type=str, default="answer_", help="prefix for answer columns if present (not used here)")
    # A zero or negative batch size would only fail later, inside the encoding
    # loop and after the model has been loaded, so reject it at parse time.
    p.add_argument("--batch_size", type=_positive_int, default=64)
    p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    return p.parse_args(argv)
34
+
35
def mean_pooling(last_hidden_state, attention_mask):
    """Average the token vectors of each sequence, skipping masked-out positions.

    Args:
        last_hidden_state: token representations, shape (B, L, H).
        attention_mask: per-token mask, shape (B, L); positions with 0 do not
            contribute to the average.

    Returns:
        Tensor of shape (B, H) with one pooled vector per sequence.
    """
    # Broadcast the mask over the hidden dimension so it can weight every feature.
    weights = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
    token_total = (last_hidden_state * weights).sum(dim=1)
    # The clamp keeps a fully masked sequence from dividing by zero.
    token_count = weights.sum(dim=1).clamp(min=1e-9)
    return token_total / token_count
42
+
43
# Texts that stand for "no value". Loading the JSONL with dtype=str can turn
# NaN/None cells into these literal strings.
_MISSING_TOKENS = frozenset({"", "nan", "none"})


def _clean_text(value):
    """Return *value* as stripped text, or None when it is missing or blank."""
    if value is None:
        return None
    if not isinstance(value, str) and pd.api.types.is_scalar(value) and pd.isna(value):
        return None
    text = str(value).strip()
    if text.lower() in _MISSING_TOKENS:
        return None
    return text


def build_question_answer_rows(df, langs, text_prefix, answer_prefix="answer_"):
    """Stack per-language question/answer columns into one long table.

    For every input row and every language code, the question is read from
    ``<text_prefix><lang>``. The answer is read from ``<answer_prefix><lang>``
    and falls back to the shared ``answer`` column when the language-specific
    value is absent or blank. Pairs lacking a question or an answer are skipped.

    Args:
        df: wide DataFrame with one row per item.
        langs: iterable of language codes to extract, in output order per row.
        text_prefix: prefix of the question columns.
        answer_prefix: prefix of the per-language answer columns.

    Returns:
        DataFrame with columns ``question``, ``answer`` and ``language``. Row
        order follows the input rows, then ``langs``. The columns are present
        even when no pair was found.
    """
    rows = []
    for record in df.to_dict("records"):
        for lang in langs:
            question = _clean_text(record.get(f"{text_prefix}{lang}"))
            if question is None:
                continue
            # Clean before deciding on the fallback: a stringified "nan" in
            # answer_<lang> must not hide a usable shared answer.
            answer = _clean_text(record.get(f"{answer_prefix}{lang}"))
            if answer is None:
                answer = _clean_text(record.get("answer"))
            if answer is None:
                continue
            rows.append({"question": question, "answer": answer, "language": lang})
    return pd.DataFrame(rows, columns=["question", "answer", "language"])
63
+
64
def main():
    """Build the retrieval artifacts: the merged CSV and the answer embeddings.

    Steps: parse the CLI options, load the JSONL dataset, stack the
    per-language question/answer pairs, write them to
    ``muril_multilingual_dataset.csv``, encode every answer with the model in
    batches (mean pooling followed by L2 normalization) and save the stacked
    tensor to ``answer_embeddings.pt``. Both files go to ``--out_dir``.

    Raises:
        SystemExit: when no question/answer pair could be extracted.
    """
    args = parse_args()
    os.makedirs(args.out_dir, exist_ok=True)
    artifact_dir = Path(args.out_dir)

    # Load the JSONL into pandas; every value is read as text.
    print("Loading dataset:", args.input_jsonl)
    df_in = pd.read_json(args.input_jsonl, lines=True, dtype=str)

    # Stack the rows across languages (question_<lang>, answer optional).
    langs = [code for code in (part.strip() for part in args.langs.split(",")) if code]
    print("Merging language columns (stack)... langs:", langs)
    rows_df = build_question_answer_rows(df_in, langs, args.text_prefix)
    if rows_df.empty:
        raise SystemExit("No question/answer rows found after merging languages. Check your columns.")
    print(f"Total rows extracted: {len(rows_df)}")

    # The CSV row order must match the embedding row order saved below.
    csv_path = artifact_dir / "muril_multilingual_dataset.csv"
    rows_df.to_csv(csv_path, index=False, encoding="utf-8")
    print("Saved merged CSV to:", csv_path)

    # Load model & tokenizer.
    print("Loading tokenizer & model from:", args.model_dir, "device:", args.device)
    tokenizer = AutoTokenizer.from_pretrained(args.model_dir, use_fast=True)
    model = AutoModel.from_pretrained(args.model_dir)
    model.to(args.device)
    model.eval()

    # Encode the answers batch by batch; each chunk is moved to the CPU.
    answers = rows_df["answer"].astype(str).tolist()
    batch_size = int(args.batch_size)
    chunks = []
    with torch.inference_mode():
        for start in tqdm(range(0, len(answers), batch_size), desc="Encoding"):
            encoded = tokenizer(
                answers[start:start + batch_size],
                padding=True,
                truncation=True,
                max_length=256,
                return_tensors="pt",
            )
            input_ids = encoded["input_ids"].to(args.device)
            attention_mask = encoded["attention_mask"].to(args.device)
            out = model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
            pooled = mean_pooling(out.last_hidden_state, attention_mask)  # (B, H)
            # Unit-length rows make a dot product equal to cosine similarity.
            chunks.append(torch.nn.functional.normalize(pooled, p=2, dim=1).cpu())
    all_embs = torch.cat(chunks, dim=0)  # (N, H)
    print("Embeddings shape:", all_embs.shape)

    embed_path = artifact_dir / "answer_embeddings.pt"
    torch.save(all_embs, embed_path)
    print("Saved embeddings to:", embed_path)

    print("Done. Artifacts in:", args.out_dir)


if __name__ == "__main__":
    main()
main.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # main.py
2
+ import os
3
+ import sys
4
+ import torch
5
+ import pandas as pd
6
+ from fastapi import FastAPI, HTTPException
7
+ from pydantic import BaseModel
8
+ from typing import List, Optional
9
+ from huggingface_hub import snapshot_download
10
+ from transformers import AutoTokenizer, AutoModel
11
+ import numpy as np
12
+
13
# -----------------------
# Configuration (env)
# -----------------------
# Every setting below can be overridden through an environment variable of the
# name passed to os.environ.get.

# Hub repository containing the artifacts (CSV, embeddings, optionally model weights).
HF_REPO = os.environ.get("HF_REPO", "Sp2503/Finetuned-multilingualdataset-MuriL-model")

# Directory handed to snapshot_download as its cache location.
CACHE_DIR = os.environ.get("HF_CACHE_DIR", "/tmp/hf_artifacts")

# File names expected inside the repository snapshot.
CSV_FILENAME = os.environ.get("CSV_FILENAME", "muril_multilingual_dataset.csv")
EMBED_FILENAME = os.environ.get("EMBED_FILENAME", "answer_embeddings.pt")
# Sub-directory of the snapshot holding the fine-tuned model; empty means the repo root.
MODEL_SUBDIR = os.environ.get("MODEL_SUBDIR", "")

# Retrieval config: how many nearest answers to consider per query.
TOP_K = int(os.environ.get("TOP_K", "1"))
# CUDA is used only when it is available and TORCH_DISABLE_CUDA is not "1".
DEVICE = "cpu" if os.environ.get("TORCH_DISABLE_CUDA", "0") == "1" or not torch.cuda.is_available() else "cuda"
34
+
35
+ # -----------------------
36
+ # Utility helpers
37
+ # -----------------------
38
def download_and_verify(repo_id: str, repo_type: str = "model", cache_dir: str = CACHE_DIR):
    """Download a Hub repository snapshot and check that its artifacts exist.

    Args:
        repo_id: Hub repository id to download.
        repo_type: repository type forwarded to ``snapshot_download``.
        cache_dir: cache directory forwarded to ``snapshot_download``.

    Returns:
        Tuple ``(model_path, csv_path, embed_path)``. ``model_path`` is the
        snapshot directory, or its ``MODEL_SUBDIR`` sub-directory when that
        setting is non-empty. The other two are the CSV and embeddings files
        inside the snapshot.

    Raises:
        RuntimeError: when the snapshot download fails (the original error is
            chained as the cause).
        FileNotFoundError: when the CSV or the embeddings file is missing.
    """
    print(f"🔍 snapshot_download: repo_id={repo_id} cache_dir={cache_dir}")
    try:
        model_dir = snapshot_download(repo_id=repo_id, repo_type=repo_type, cache_dir=cache_dir)
    except Exception as e:
        raise RuntimeError(
            f"Failed to snapshot_download repo {repo_id}. "
            "Make sure HUGGINGFACE_HUB_TOKEN (or HF_TOKEN) is set for private repos and that repo exists."
        ) from e

    csv_path = os.path.join(model_dir, CSV_FILENAME)
    embed_path = os.path.join(model_dir, EMBED_FILENAME)
    model_path = os.path.join(model_dir, MODEL_SUBDIR) if MODEL_SUBDIR else model_dir

    # Report every missing artifact at once rather than one per run.
    missing = [path for path in (csv_path, embed_path) if not os.path.isfile(path)]
    if missing:
        raise FileNotFoundError(
            f"Missing artifact(s) in the downloaded repo: {missing}\n"
            f"Push {CSV_FILENAME} and {EMBED_FILENAME} to the repo '{repo_id}' or set HF_REPO to the correct one."
        )

    print(f"✅ Downloaded snapshot to: {model_dir}")
    print(f"✅ Found CSV at: {csv_path}")
    print(f"✅ Found embeddings at: {embed_path}")
    return model_path, csv_path, embed_path
73
+
74
def load_embeddings(emb_path: str, csv_len_expected: Optional[int] = None):
    """Load the answer-embedding matrix from disk and validate its shape.

    Args:
        emb_path: path to a file written with ``torch.save``.
        csv_len_expected: when given, the number of rows the matrix must have
            (the number of dataset rows it is aligned with).

    Returns:
        The loaded 2-D tensor, mapped to the CPU.

    Raises:
        ValueError: when the file does not hold a tensor, the tensor is not
            2-D, or its row count differs from ``csv_len_expected``.
    """
    # The file comes from a downloaded repository, so restrict unpickling to
    # tensors and plain containers instead of executing arbitrary pickle code.
    emb = torch.load(emb_path, map_location="cpu", weights_only=True)
    if not isinstance(emb, torch.Tensor):
        raise ValueError(f"Embeddings file {emb_path} did not load a torch.Tensor (type={type(emb)}).")
    if emb.ndim != 2:
        raise ValueError(f"Embeddings tensor must have shape [N, D]. Got {tuple(emb.shape)}.")
    if csv_len_expected is not None and emb.shape[0] != csv_len_expected:
        raise ValueError(f"Mismatch: CSV rows={csv_len_expected} but embeddings rows={emb.shape[0]}. Ensure ordering matches.")
    return emb
83
+
84
def mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor):
    """Return the mask-weighted mean of the token vectors of each sequence.

    ``attention_mask`` is broadcast over the hidden dimension, so positions
    whose mask is 0 add nothing to the sum or to the divisor.
    """
    weights = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
    total = (last_hidden_state * weights).sum(dim=1)
    # The clamp keeps a fully masked sequence from dividing by zero.
    count = weights.sum(dim=1).clamp(min=1e-9)
    return total / count
89
+
90
# -----------------------
# Startup: download artifacts
# -----------------------
# Everything in this section runs at import time, so the server only starts
# once the dataset, the embeddings and the encoder are all in memory.
try:
    MODEL_DIR, CSV_PATH, EMBED_PATH = download_and_verify(HF_REPO, repo_type="model", cache_dir=CACHE_DIR)
except Exception as e:
    # Log to stderr, then re-raise so the process fails fast.
    print("✖ ERROR during snapshot_download or verification:", e, file=sys.stderr)
    raise

# -----------------------
# Load CSV and embeddings
# -----------------------
print("📥 Loading CSV:", CSV_PATH)
# Read every column as text and replace missing cells with empty strings.
df = pd.read_csv(CSV_PATH, dtype=str).fillna("")
# Ensure required columns
if not {"question", "answer"}.issubset(set(df.columns)):
    raise RuntimeError(f"CSV must contain 'question' and 'answer' columns. Found: {df.columns.tolist()}")

# Default language column if missing
if "language" not in df.columns:
    df["language"] = "en"

print("📥 Loading embeddings (this may take a second)...")
# Row i of the embedding matrix must describe row i of the CSV, so the row
# counts are checked against each other.
answer_embeddings = load_embeddings(EMBED_PATH, csv_len_expected=len(df))
# Ensure embeddings are float32
if answer_embeddings.dtype != torch.float32:
    answer_embeddings = answer_embeddings.to(torch.float32)
# Normalize stored embeddings so that dot-product = cosine (queries are
# normalized the same way when they are encoded).
answer_embeddings = torch.nn.functional.normalize(answer_embeddings, p=2, dim=1)

print("✅ Loaded dataset rows:", len(df), "embedding dim:", answer_embeddings.shape[1])

# -----------------------
# Load MuRIL model for query encoding (AutoModel + tokenizer)
# -----------------------
print("⚙️ Loading tokenizer & model for text encoding from:", MODEL_DIR)
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True)
    encoder_model = AutoModel.from_pretrained(MODEL_DIR)
    encoder_model.to(DEVICE)
    encoder_model.eval()
except Exception as e:
    # fail fast with helpful message
    raise RuntimeError(
        f"Failed to load model/tokenizer from {MODEL_DIR}. "
        "If the repo stores only adapters/LoRA, you must load base model + apply adapters. "
        "Ensure the repo contains full model files or set MODEL_SUBDIR appropriately."
    ) from e
138
+
139
def encode_query(texts: List[str], batch_size: int = 32):
    """Encode texts into L2-normalized sentence embeddings.

    The texts are tokenized in slices of ``batch_size`` (padded, truncated to
    256 tokens), run through the encoder under inference mode, mean-pooled
    over the attention mask and normalized to unit length.

    Returns:
        CPU tensor with one row per input text, in input order.
    """
    chunks = []
    with torch.inference_mode():
        for start in range(0, len(texts), batch_size):
            enc = tokenizer(
                texts[start : start + batch_size],
                padding=True,
                truncation=True,
                max_length=256,
                return_tensors="pt",
            )
            input_ids = enc["input_ids"].to(DEVICE)
            attention_mask = enc["attention_mask"].to(DEVICE)
            hidden = encoder_model(
                input_ids=input_ids, attention_mask=attention_mask, return_dict=True
            ).last_hidden_state
            unit = torch.nn.functional.normalize(mean_pool(hidden, attention_mask), p=2, dim=1)
            # Move each chunk off the device before stacking.
            chunks.append(unit.cpu())
    return torch.cat(chunks, dim=0)
152
+
153
# -----------------------
# FastAPI app
# -----------------------
# ASGI application object; uvicorn is pointed at it as "main:app".
app = FastAPI(title="MuRIL Multilingual QA API (Hub-backed artifacts)")


# Request body of POST /get-answer.
class QueryRequest(BaseModel):
    question: str  # question text; the endpoint answers HTTP 400 when it is blank
    lang: Optional[str] = None  # optional language code restricting the search to matching dataset rows


# Response body of POST /get-answer.
class QAResponse(BaseModel):
    answer: str  # answer text of the best-scoring dataset row
    detected_lang: str  # the language filter that was applied, or "all" when none was given
    top_k: Optional[List[dict]] = None  # ranked candidates; filled only when more than one is returned
166
+
167
@app.get("/")
def root():
    """Health/status endpoint.

    Returns a JSON object with a status line, the sorted list of language
    codes present in the loaded dataset, the artifact repository id and the
    number of dataset rows.
    """
    langs = sorted(df["language"].unique().tolist())
    return {
        "status": "✅ MuRIL Multilingual QA API Running",
        "available_languages": langs,
        "model_repo": HF_REPO,
        "loaded_rows": len(df),
    }
176
+
177
@app.post("/get-answer", response_model=QAResponse)
def get_answer_endpoint(req: QueryRequest):
    """Return the stored answer whose embedding is closest to the question.

    The question is encoded and compared by dot product against the stored
    answer embeddings (both sides are unit-normalized, so the score is cosine
    similarity). When ``req.lang`` is given, only dataset rows with exactly
    that language code are searched.

    Returns:
        QAResponse with the best answer, the applied language filter ("all"
        when none was given) and, when more than one candidate is requested
        through TOP_K, the ranked candidates in ``top_k``.

    Raises:
        HTTPException: 400 when the question is blank, 503 when there are no
            rows to search.
    """
    qtext = (req.question or "").strip()
    if not qtext:
        raise HTTPException(status_code=400, detail="Empty question")

    lang_filter = (req.lang or "").strip()
    filtered_df = df
    filtered_embeddings = answer_embeddings
    if lang_filter:
        mask = df["language"] == lang_filter
        if not mask.any():
            # no data for this language
            return QAResponse(answer=f"⚠️ No data found for language '{lang_filter}'.", detected_lang=lang_filter)
        filtered_df = df[mask].reset_index(drop=True)
        # Index with an explicit boolean tensor rather than relying on the
        # implicit conversion of a numpy array.
        filtered_embeddings = answer_embeddings[torch.from_numpy(mask.to_numpy(dtype=bool))]

    if len(filtered_df) == 0:
        # topk cannot select from an empty score vector; report it cleanly.
        raise HTTPException(status_code=503, detail="Dataset is empty")

    # encode query
    q_emb = encode_query([qtext], batch_size=1)  # shape (1, D)
    sims = torch.matmul(q_emb, filtered_embeddings.T).squeeze(0)  # (N,)
    # Clamp k into [1, N] so a misconfigured TOP_K cannot break topk.
    k = max(1, min(TOP_K, len(filtered_df)))
    topv, topi = torch.topk(sims, k=k)
    best_idx = int(topi[0].item())
    answer = filtered_df.iloc[best_idx]["answer"]

    top_k_payload = None
    if k > 1:
        top_k_payload = []
        for rank, (score, idx) in enumerate(zip(topv.tolist(), topi.tolist()), start=1):
            row = filtered_df.iloc[int(idx)]
            top_k_payload.append({
                "rank": rank,
                "score": float(score),
                "answer": row["answer"],
                "question": row["question"],
            })

    return QAResponse(answer=answer, detected_lang=lang_filter or "all", top_k=top_k_payload)
215
+
216
# -----------------------
# Run server (if invoked directly)
# -----------------------
if __name__ == "__main__":
    import uvicorn

    # The port comes from the PORT environment variable, defaulting to 7860.
    listen_port = int(os.getenv("PORT", "7860"))
    uvicorn.run("main:app", host="0.0.0.0", port=listen_port, workers=1)
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ fastapi==0.115.5
2
+ uvicorn[standard]==0.32.0
3
+ pandas==2.2.3
4
+ torch>=2.1.0
5
+ transformers==4.46.0
6
+ huggingface_hub>=0.14.1
7
+ tqdm