Sai809701 committed
Commit · ce552a1
1 Parent(s): 3b2edfc

muril api

Files changed:
- Dockerfile +25 -0
- embed_build.py +114 -0
- main.py +221 -0
- requirements.txt +7 -0
Dockerfile ADDED
@@ -0,0 +1,25 @@
# Dockerfile.muril
FROM python:3.11-slim

ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential git curl ca-certificates \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /app

COPY requirements.txt /app/requirements.txt
RUN pip install --no-cache-dir -r /app/requirements.txt

# Copy app
COPY main.py /app/main.py

ENV HF_HOME=/app/hf_cache
ENV TRANSFORMERS_CACHE=/app/hf_cache
ENV TORCH_DISABLE_CUDA=1
ENV OUT_DIR=/app/export_artifacts
ENV MODEL_DIR=/app/muril_multilang_out
ENV PORT=7860

EXPOSE 7860
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
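
A minimal build-and-run sketch for this image (the tag name is arbitrary; HUGGINGFACE_HUB_TOKEN is only needed when HF_REPO points at a private repo, and "hf_..." is a placeholder):

    # build from the directory containing the four files in this commit
    docker build -t muril-qa-api .
    # run, overriding the artifact repo if needed
    docker run -p 7860:7860 \
      -e HF_REPO="Sp2503/Finetuned-multilingualdataset-MuriL-model" \
      -e HUGGINGFACE_HUB_TOKEN="hf_..." \
      muril-qa-api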
embed_build.py ADDED
@@ -0,0 +1,114 @@
# embed_build_muril.py
"""
Produce answer embeddings for the dataset using a fine-tuned MuRIL model.
Saves:
 - muril_multilingual_dataset.csv (columns: question, answer, language)
 - answer_embeddings.pt (torch tensor of shape [N, D], float32, on CPU)

Usage:
python embed_build_muril.py \
    --model_dir ./muril_multilang_out \
    --input_jsonl /path/to/legal_multilingual_QA_10k.jsonl \
    --out_dir ./export_artifacts \
    --batch_size 64
"""
import argparse, os
from pathlib import Path

import torch
import pandas as pd
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModel

def parse_args():
    p = argparse.ArgumentParser()
    p.add_argument("--model_dir", type=str, default="./muril_multilang_out", help="Path or HF repo id of fine-tuned MuRIL")
    p.add_argument("--input_jsonl", type=str, required=True, help="Path to legal_multilingual_QA_10k.jsonl")
    p.add_argument("--out_dir", type=str, default="./export_artifacts")
    p.add_argument("--langs", type=str, default="en,hi,mr,ta,bn,gu,kn,ml,pa,or,as,ur,sa,ne", help="comma-separated languages to merge (will stack)")
    p.add_argument("--text_prefix", type=str, default="question_", help="prefix for question columns in the JSONL")
    p.add_argument("--answer_col_prefix", type=str, default="answer_", help="prefix for answer columns if present (not used here)")
    p.add_argument("--batch_size", type=int, default=64)
    p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    return p.parse_args()

def mean_pooling(last_hidden_state, attention_mask):
    # last_hidden_state: (B, L, H); attention_mask: (B, L)
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
    sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, 1)
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return sum_embeddings / sum_mask

def build_question_answer_rows(df, langs, text_prefix):
    rows = []
    for _, r in df.iterrows():
        # merge all available language question/answer pairs by stacking
        for lang in langs:
            qcol = f"{text_prefix}{lang}"
            acol = f"answer_{lang}"
            # If the dataset has question_<lang> and answer_<lang>, use them; otherwise fall back to question_<lang> plus a common 'answer' field.
            q = r.get(qcol, None)
            if q is None or str(q).strip() == "" or str(q).lower() == "nan":
                continue
            # pick answer_<lang> if present, else the "answer" column
            if acol in df.columns and pd.notna(r.get(acol)):
                a = r.get(acol)
            else:
                a = r.get("answer", None)
            if a is None or str(a).strip() == "" or str(a).lower() == "nan":
                continue
            rows.append({"question": str(q).strip(), "answer": str(a).strip(), "language": lang})
    return pd.DataFrame(rows)

def main():
    args = parse_args()
    os.makedirs(args.out_dir, exist_ok=True)
    # load JSONL into pandas
    print("Loading dataset:", args.input_jsonl)
    df_in = pd.read_json(args.input_jsonl, lines=True, dtype=str)
    # Build rows stacked across languages (question_<lang>, answer optional)
    langs = [l.strip() for l in args.langs.split(",") if l.strip()]
    print("Merging language columns (stack)... langs:", langs)
    rows_df = build_question_answer_rows(df_in, langs, args.text_prefix)
    if rows_df.empty:
        raise SystemExit("No question/answer rows found after merging languages. Check your columns.")
    print(f"Total rows extracted: {len(rows_df)}")
    # Save CSV (row order matters: embeddings are saved in the same order)
    csv_path = Path(args.out_dir) / "muril_multilingual_dataset.csv"
    rows_df.to_csv(csv_path, index=False, encoding="utf-8")
    print("Saved merged CSV to:", csv_path)

    # Load model & tokenizer
    print("Loading tokenizer & model from:", args.model_dir, "device:", args.device)
    tokenizer = AutoTokenizer.from_pretrained(args.model_dir, use_fast=True)
    model = AutoModel.from_pretrained(args.model_dir)
    model.to(args.device)
    model.eval()

    # Encode answers in batches
    answers = rows_df["answer"].astype(str).tolist()
    batch_size = int(args.batch_size)
    all_embs = []
    with torch.inference_mode():
        for i in tqdm(range(0, len(answers), batch_size), desc="Encoding"):
            batch_texts = answers[i:i+batch_size]
            encoded = tokenizer(batch_texts, padding=True, truncation=True, max_length=256, return_tensors="pt")
            input_ids = encoded["input_ids"].to(args.device)
            attention_mask = encoded["attention_mask"].to(args.device)
            out = model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
            last_hidden = out.last_hidden_state  # (B, L, H)
            pooled = mean_pooling(last_hidden, attention_mask)  # (B, H)
            # L2-normalize embeddings (optional but recommended for cosine similarity)
            pooled = torch.nn.functional.normalize(pooled, p=2, dim=1)
            all_embs.append(pooled.cpu())
    all_embs = torch.cat(all_embs, dim=0)  # (N, H)
    print("Embeddings shape:", all_embs.shape)
    embed_path = Path(args.out_dir) / "answer_embeddings.pt"
    torch.save(all_embs, embed_path)
    print("Saved embeddings to:", embed_path)

    print("Done. Artifacts in:", args.out_dir)

if __name__ == "__main__":
    main()
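
For reference, one hypothetical input line in the shape this script scans for: a question_<lang> column per language, paired with either answer_<lang> or a shared answer field (the text below is invented for illustration):

    {"question_en": "What is anticipatory bail?", "question_hi": "अग्रिम जमानत क्या है?", "answer": "Anticipatory bail is bail granted in anticipation of arrest."}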
main.py ADDED
@@ -0,0 +1,221 @@
# main.py
import os
import sys
import torch
import pandas as pd
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, AutoModel

# -----------------------
# Configuration (env)
# -----------------------
# Repo containing artifacts (CSV, embeddings, optionally model weights)
HF_REPO = os.getenv(
    "HF_REPO",
    "Sp2503/Finetuned-multilingualdataset-MuriL-model"
)

# Where snapshot_download will cache/put the files
CACHE_DIR = os.getenv("HF_CACHE_DIR", "/tmp/hf_artifacts")

# Names expected inside the repo
CSV_FILENAME = os.getenv("CSV_FILENAME", "muril_multilingual_dataset.csv")
EMBED_FILENAME = os.getenv("EMBED_FILENAME", "answer_embeddings.pt")
# If you stored the fine-tuned model in a subdirectory of the same repo, set this to that path (optional)
MODEL_SUBDIR = os.getenv("MODEL_SUBDIR", "")  # leave empty if the model is in the repo root

# Retrieval config
TOP_K = int(os.getenv("TOP_K", "1"))
DEVICE = "cuda" if torch.cuda.is_available() and os.getenv("TORCH_DISABLE_CUDA", "0") != "1" else "cpu"

# -----------------------
# Utility helpers
# -----------------------
def download_and_verify(repo_id: str, repo_type: str = "model", cache_dir: str = CACHE_DIR):
    """
    Download a repo snapshot and verify that the CSV + embeddings exist.
    Returns absolute paths: model_dir, csv_path, embed_path.
    """
    print(f"📦 snapshot_download: repo_id={repo_id} cache_dir={cache_dir}")
    try:
        model_dir = snapshot_download(repo_id=repo_id, repo_type=repo_type, cache_dir=cache_dir)
    except Exception as e:
        raise RuntimeError(
            f"Failed to snapshot_download repo {repo_id}. "
            "Make sure HUGGINGFACE_HUB_TOKEN (or HF_TOKEN) is set for private repos and that the repo exists."
        ) from e

    csv_path = os.path.join(model_dir, CSV_FILENAME)
    embed_path = os.path.join(model_dir, EMBED_FILENAME)
    model_path = os.path.join(model_dir, MODEL_SUBDIR) if MODEL_SUBDIR else model_dir

    # checks
    missing = []
    if not os.path.isfile(csv_path):
        missing.append(csv_path)
    if not os.path.isfile(embed_path):
        missing.append(embed_path)

    if missing:
        raise FileNotFoundError(
            f"Missing artifact(s) in the downloaded repo: {missing}\n"
            f"Push {CSV_FILENAME} and {EMBED_FILENAME} to the repo '{repo_id}' or set HF_REPO to the correct one."
        )

    print(f"✅ Downloaded snapshot to: {model_dir}")
    print(f"✅ Found CSV at: {csv_path}")
    print(f"✅ Found embeddings at: {embed_path}")
    return model_path, csv_path, embed_path

def load_embeddings(emb_path: str, csv_len_expected: Optional[int] = None):
    emb = torch.load(emb_path, map_location="cpu")
    if not isinstance(emb, torch.Tensor):
        raise ValueError(f"Embeddings file {emb_path} did not load a torch.Tensor (type={type(emb)}).")
    if emb.ndim != 2:
        raise ValueError(f"Embeddings tensor must have shape [N, D]. Got {tuple(emb.shape)}.")
    if csv_len_expected is not None and emb.shape[0] != csv_len_expected:
        raise ValueError(f"Mismatch: CSV rows={csv_len_expected} but embeddings rows={emb.shape[0]}. Ensure ordering matches.")
    return emb

def mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor):
    mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
    summed = torch.sum(last_hidden_state * mask, dim=1)
    denom = torch.clamp(mask.sum(dim=1), min=1e-9)
    return summed / denom

# -----------------------
# Startup: download artifacts
# -----------------------
try:
    MODEL_DIR, CSV_PATH, EMBED_PATH = download_and_verify(HF_REPO, repo_type="model", cache_dir=CACHE_DIR)
except Exception as e:
    print("❌ ERROR during snapshot_download or verification:", e, file=sys.stderr)
    raise

# -----------------------
# Load CSV and embeddings
# -----------------------
print("📥 Loading CSV:", CSV_PATH)
df = pd.read_csv(CSV_PATH, dtype=str).fillna("")
# Ensure required columns
if not {"question", "answer"}.issubset(set(df.columns)):
    raise RuntimeError(f"CSV must contain 'question' and 'answer' columns. Found: {df.columns.tolist()}")

# Default language column if missing
if "language" not in df.columns:
    df["language"] = "en"

print("📥 Loading embeddings (this may take a second)...")
answer_embeddings = load_embeddings(EMBED_PATH, csv_len_expected=len(df))
# Ensure embeddings are float32 and normalized (queries are normalized at encode time;
# if the stored embeddings are not normalized, we normalize them here)
if answer_embeddings.dtype != torch.float32:
    answer_embeddings = answer_embeddings.to(torch.float32)
# Normalize stored embeddings so dot product = cosine similarity
answer_embeddings = torch.nn.functional.normalize(answer_embeddings, p=2, dim=1)

print("✅ Loaded dataset rows:", len(df), "embedding dim:", answer_embeddings.shape[1])

# -----------------------
# Load MuRIL model for query encoding (AutoModel + tokenizer)
# -----------------------
print("⚙️ Loading tokenizer & model for text encoding from:", MODEL_DIR)
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True)
    encoder_model = AutoModel.from_pretrained(MODEL_DIR)
    encoder_model.to(DEVICE)
    encoder_model.eval()
except Exception as e:
    # fail fast with a helpful message
    raise RuntimeError(
        f"Failed to load model/tokenizer from {MODEL_DIR}. "
        "If the repo stores only adapters/LoRA, you must load the base model and apply the adapters. "
        "Ensure the repo contains full model files or set MODEL_SUBDIR appropriately."
    ) from e

def encode_query(texts: List[str], batch_size: int = 32):
    all_embs = []
    with torch.inference_mode():
        for i in range(0, len(texts), batch_size):
            batch = texts[i : i + batch_size]
            enc = tokenizer(batch, padding=True, truncation=True, max_length=256, return_tensors="pt")
            input_ids = enc["input_ids"].to(DEVICE)
            attention_mask = enc["attention_mask"].to(DEVICE)
            out = encoder_model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
            pooled = mean_pool(out.last_hidden_state, attention_mask)
            pooled = torch.nn.functional.normalize(pooled, p=2, dim=1)
            all_embs.append(pooled.cpu())
    return torch.cat(all_embs, dim=0)

# -----------------------
# FastAPI app
# -----------------------
app = FastAPI(title="MuRIL Multilingual QA API (Hub-backed artifacts)")

class QueryRequest(BaseModel):
    question: str
    lang: Optional[str] = None

class QAResponse(BaseModel):
    answer: str
    detected_lang: str
    top_k: Optional[List[dict]] = None

@app.get("/")
def root():
    langs = sorted(df["language"].unique().tolist())
    return {
        "status": "✅ MuRIL Multilingual QA API Running",
        "available_languages": langs,
        "model_repo": HF_REPO,
        "loaded_rows": len(df),
    }

@app.post("/get-answer", response_model=QAResponse)
def get_answer_endpoint(req: QueryRequest):
    qtext = (req.question or "").strip()
    if not qtext:
        raise HTTPException(status_code=400, detail="Empty question")

    lang_filter = (req.lang or "").strip()
    filtered_df = df
    filtered_embeddings = answer_embeddings
    if lang_filter:
        mask = df["language"] == lang_filter
        if not mask.any():
            # no data for this language
            return QAResponse(answer=f"⚠️ No data found for language '{lang_filter}'.", detected_lang=lang_filter)
        filtered_df = df[mask].reset_index(drop=True)
        # boolean row mask; convert to a torch tensor for indexing
        filtered_embeddings = answer_embeddings[torch.from_numpy(mask.values)]

    # encode query
    q_emb = encode_query([qtext], batch_size=1)  # shape (1, D)
    sims = torch.matmul(q_emb, filtered_embeddings.T).squeeze(0)  # (N,)
    k = max(1, min(TOP_K, len(filtered_df)))
    topv, topi = torch.topk(sims, k=k)
    best_idx = int(topi[0].item())
    answer = filtered_df.iloc[best_idx]["answer"]

    top_k_payload = None
    if k > 1:
        top_k_payload = []
        for rank in range(k):
            idx = int(topi[rank].item())
            top_k_payload.append({
                "rank": rank + 1,
                "score": float(topv[rank].item()),
                "answer": filtered_df.iloc[idx]["answer"],
                "question": filtered_df.iloc[idx]["question"],
            })

    return QAResponse(answer=answer, detected_lang=lang_filter or "all", top_k=top_k_payload)

# -----------------------
# Run server (if invoked directly)
# -----------------------
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("main:app", host="0.0.0.0", port=int(os.getenv("PORT", "7860")), workers=1)
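
A quick client sketch for exercising the endpoint once the server is up (assumes the default PORT=7860 on localhost; the question text is an invented example, and requests is not in requirements.txt, so install it separately):

    import requests

    # /get-answer takes {"question": ..., "lang": ...}; lang is optional and filters rows by language
    resp = requests.post(
        "http://localhost:7860/get-answer",
        json={"question": "What is anticipatory bail?", "lang": "en"},
    )
    resp.raise_for_status()
    body = resp.json()
    print(body["detected_lang"], "->", body["answer"])
    # body["top_k"] is only populated when the TOP_K env var is set above 1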
requirements.txt ADDED
@@ -0,0 +1,7 @@
fastapi==0.115.5
uvicorn[standard]==0.32.0
pandas==2.2.3
torch>=2.1.0
transformers==4.46.0
huggingface_hub>=0.14.1
tqdm