import os
import sys
from typing import List, Optional

import pandas as pd
import torch
from fastapi import FastAPI, HTTPException
from huggingface_hub import snapshot_download
from pydantic import BaseModel
from transformers import AutoModel, AutoTokenizer
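
# -----------------------------------------------------------------------------
# Configuration: every value below can be overridden via an environment
# variable, so the same container can point at a different repo or dataset.
# -----------------------------------------------------------------------------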
HF_REPO = os.getenv(
    "HF_REPO",
    "Sp2503/Finetuned-multilingualdataset-MuriL-model",
)
CACHE_DIR = os.getenv("HF_CACHE_DIR", "/tmp/hf_artifacts")
CSV_FILENAME = os.getenv("CSV_FILENAME", "muril_multilingual_dataset.csv")
EMBED_FILENAME = os.getenv("EMBED_FILENAME", "answer_embeddings.pt")
MODEL_SUBDIR = os.getenv("MODEL_SUBDIR", "")
TOP_K = int(os.getenv("TOP_K", "1"))
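
# Prefer CUDA when available; setting TORCH_DISABLE_CUDA=1 forces CPU inference.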
DEVICE = "cuda" if torch.cuda.is_available() and os.getenv("TORCH_DISABLE_CUDA", "0") != "1" else "cpu"
def download_and_verify(repo_id: str, repo_type: str = "model", cache_dir: str = CACHE_DIR):
    """
    Download a repo snapshot and verify that the CSV and embeddings exist.

    Returns absolute paths: (model_path, csv_path, embed_path).
    """
    print(f"🔍 snapshot_download: repo_id={repo_id} cache_dir={cache_dir}")
    try:
        model_dir = snapshot_download(repo_id=repo_id, repo_type=repo_type, cache_dir=cache_dir)
    except Exception as e:
        raise RuntimeError(
            f"Failed to snapshot_download repo {repo_id}. "
            "Make sure HUGGINGFACE_HUB_TOKEN (or HF_TOKEN) is set for private repos and that the repo exists."
        ) from e

    csv_path = os.path.join(model_dir, CSV_FILENAME)
    embed_path = os.path.join(model_dir, EMBED_FILENAME)
    model_path = os.path.join(model_dir, MODEL_SUBDIR) if MODEL_SUBDIR else model_dir

    missing = []
    if not os.path.isfile(csv_path):
        missing.append(csv_path)
    if not os.path.isfile(embed_path):
        missing.append(embed_path)

    if missing:
        raise FileNotFoundError(
            f"Missing artifact(s) in the downloaded repo: {missing}\n"
            f"Push {CSV_FILENAME} and {EMBED_FILENAME} to the repo '{repo_id}' or set HF_REPO to the correct one."
        )

    print(f"✅ Downloaded snapshot to: {model_dir}")
    print(f"✅ Found CSV at: {csv_path}")
    print(f"✅ Found embeddings at: {embed_path}")
    return model_path, csv_path, embed_path
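

# The embeddings file must hold a 2-D float tensor of shape [N, D] whose rows
# line up one-to-one with the rows of the dataset CSV.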
def load_embeddings(emb_path: str, csv_len_expected: Optional[int] = None):
    emb = torch.load(emb_path, map_location="cpu")
    if not isinstance(emb, torch.Tensor):
        raise ValueError(f"Embeddings file {emb_path} did not load a torch.Tensor (type={type(emb)}).")
    if emb.ndim != 2:
        raise ValueError(f"Embeddings tensor must have shape [N, D]. Got {tuple(emb.shape)}.")
    if csv_len_expected is not None and emb.shape[0] != csv_len_expected:
        raise ValueError(f"Mismatch: CSV rows={csv_len_expected} but embeddings rows={emb.shape[0]}. Ensure ordering matches.")
    return emb
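

# Mean pooling: average the token hidden states, weighting by the attention
# mask so that padding tokens do not contribute to the sentence embedding.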
def mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor):
    mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
    summed = torch.sum(last_hidden_state * mask, dim=1)
    denom = torch.clamp(mask.sum(dim=1), min=1e-9)
    return summed / denom
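

# -----------------------------------------------------------------------------
# Startup: fetch artifacts, load the dataset, and prepare the answer embeddings.
# Failures here should crash the process early rather than serve a broken app.
# -----------------------------------------------------------------------------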
try:
    MODEL_DIR, CSV_PATH, EMBED_PATH = download_and_verify(HF_REPO, repo_type="model", cache_dir=CACHE_DIR)
except Exception as e:
    print("❌ ERROR during snapshot_download or verification:", e, file=sys.stderr)
    raise

print("📥 Loading CSV:", CSV_PATH)
df = pd.read_csv(CSV_PATH, dtype=str).fillna("")

if not {"question", "answer"}.issubset(df.columns):
    raise RuntimeError(f"CSV must contain 'question' and 'answer' columns. Found: {df.columns.tolist()}")

if "language" not in df.columns:
    df["language"] = "en"

print("📥 Loading embeddings (this may take a second)...")
answer_embeddings = load_embeddings(EMBED_PATH, csv_len_expected=len(df))

if answer_embeddings.dtype != torch.float32:
    answer_embeddings = answer_embeddings.to(torch.float32)
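
# L2-normalize once so a plain dot product against a normalized query vector
# is exactly cosine similarity.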
answer_embeddings = torch.nn.functional.normalize(answer_embeddings, p=2, dim=1)

print("✅ Loaded dataset rows:", len(df), "embedding dim:", answer_embeddings.shape[1])
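

# -----------------------------------------------------------------------------
# Encoder: tokenizer + model used to embed incoming questions at request time.
# -----------------------------------------------------------------------------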
print("⚙️ Loading tokenizer & model for text encoding from:", MODEL_DIR)
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True)
    encoder_model = AutoModel.from_pretrained(MODEL_DIR)
    encoder_model.to(DEVICE)
    encoder_model.eval()
except Exception as e:
    raise RuntimeError(
        f"Failed to load model/tokenizer from {MODEL_DIR}. "
        "If the repo stores only adapters/LoRA, you must load the base model and apply the adapters. "
        "Ensure the repo contains full model files or set MODEL_SUBDIR appropriately."
    ) from e
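

# Embed texts with mean pooling + L2 normalization; this must match how the
# stored answer embeddings were produced for the cosine scores to be meaningful.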
def encode_query(texts: List[str], batch_size: int = 32):
    all_embs = []
    with torch.inference_mode():
        for i in range(0, len(texts), batch_size):
            batch = texts[i : i + batch_size]
            enc = tokenizer(batch, padding=True, truncation=True, max_length=256, return_tensors="pt")
            input_ids = enc["input_ids"].to(DEVICE)
            attention_mask = enc["attention_mask"].to(DEVICE)
            out = encoder_model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
            pooled = mean_pool(out.last_hidden_state, attention_mask)
            pooled = torch.nn.functional.normalize(pooled, p=2, dim=1)
            all_embs.append(pooled.cpu())
    return torch.cat(all_embs, dim=0)
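

# -----------------------------------------------------------------------------
# FastAPI app and request/response schemas.
# -----------------------------------------------------------------------------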
app = FastAPI(title="MuRIL Multilingual QA API (Hub-backed artifacts)")


class QueryRequest(BaseModel):
    question: str
    lang: Optional[str] = None


class QAResponse(BaseModel):
    answer: str
    detected_lang: str
    top_k: Optional[List[dict]] = None


@app.get("/")
def root():
    langs = sorted(df["language"].unique().tolist())
    return {
        "status": "✅ MuRIL Multilingual QA API Running",
        "available_languages": langs,
        "model_repo": HF_REPO,
        "loaded_rows": len(df),
    }
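

# /get-answer: optionally filter the dataset by language, embed the question,
# and return the answer whose precomputed embedding is most similar. When
# TOP_K > 1, the runners-up are included with their cosine scores.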
@app.post("/get-answer", response_model=QAResponse)
def get_answer_endpoint(req: QueryRequest):
    qtext = (req.question or "").strip()
    if not qtext:
        raise HTTPException(status_code=400, detail="Empty question")

    lang_filter = (req.lang or "").strip()
    filtered_df = df
    filtered_embeddings = answer_embeddings
    if lang_filter:
        mask = df["language"] == lang_filter
        if not mask.any():
            return QAResponse(answer=f"⚠️ No data found for language '{lang_filter}'.", detected_lang=lang_filter)
        filtered_df = df[mask].reset_index(drop=True)
        filtered_embeddings = answer_embeddings[mask.values]
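
    # Both sides are L2-normalized, so the matmul below yields cosine similarities.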
    q_emb = encode_query([qtext], batch_size=1)
    sims = torch.matmul(q_emb, filtered_embeddings.T).squeeze(0)
    k = max(1, min(TOP_K, len(filtered_df)))
    topv, topi = torch.topk(sims, k=k)
    best_idx = int(topi[0].item())
    answer = filtered_df.iloc[best_idx]["answer"]

    top_k_payload = None
    if k > 1:
        top_k_payload = []
        for rank in range(k):
            idx = int(topi[rank].item())
            top_k_payload.append({
                "rank": rank + 1,
                "score": float(topv[rank].item()),
                "answer": filtered_df.iloc[idx]["answer"],
                "question": filtered_df.iloc[idx]["question"],
            })

    return QAResponse(answer=answer, detected_lang=lang_filter or "all", top_k=top_k_payload)
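

# Example request (hypothetical payload; the server listens on port 7860 by default):
#   curl -X POST http://localhost:7860/get-answer \
#        -H "Content-Type: application/json" \
#        -d '{"question": "What is MuRIL?", "lang": "en"}'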


if __name__ == "__main__":
    import uvicorn

    uvicorn.run("main:app", host="0.0.0.0", port=int(os.getenv("PORT", "7860")), workers=1)