MuriL-2.0 / main.py
Sai809701
muril api
ce552a1
raw
history blame
8.74 kB
# main.py
import os
import sys
import torch
import pandas as pd
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, AutoModel
import numpy as np
# -----------------------
# Configuration (env)
# -----------------------
# Hub repo holding the artifacts (CSV, embeddings, and optionally model weights).
HF_REPO = os.environ.get("HF_REPO", "Sp2503/Finetuned-multilingualdataset-MuriL-model")
# Local directory handed to snapshot_download as its cache location.
CACHE_DIR = os.environ.get("HF_CACHE_DIR", "/tmp/hf_artifacts")
# File names looked up at the root of the downloaded snapshot.
CSV_FILENAME = os.environ.get("CSV_FILENAME", "muril_multilingual_dataset.csv")
EMBED_FILENAME = os.environ.get("EMBED_FILENAME", "answer_embeddings.pt")
# Optional sub-directory (relative to the snapshot root) containing the model and
# tokenizer files; an empty string means they sit at the snapshot root.
MODEL_SUBDIR = os.environ.get("MODEL_SUBDIR", "")
# Number of nearest stored answers ranked per query.
TOP_K = int(os.environ.get("TOP_K", "1"))
# Prefer CUDA when available, unless TORCH_DISABLE_CUDA=1 forces the CPU.
DEVICE = (
    "cpu"
    if not torch.cuda.is_available() or os.environ.get("TORCH_DISABLE_CUDA", "0") == "1"
    else "cuda"
)
# -----------------------
# Utility helpers
# -----------------------
def download_and_verify(repo_id: str, repo_type: str = "model", cache_dir: str = CACHE_DIR):
    """
    Download a Hub repo snapshot and verify the expected artifacts are present.

    Args:
        repo_id: Hugging Face Hub repo id (``"owner/name"``).
        repo_type: Repo type forwarded to ``snapshot_download`` ("model" by default).
        cache_dir: Local cache directory for the snapshot.

    Returns:
        Tuple ``(model_path, csv_path, embed_path)``. ``model_path`` is the
        snapshot directory itself, or ``<snapshot>/<MODEL_SUBDIR>`` when
        MODEL_SUBDIR is set.

    Raises:
        RuntimeError: if ``snapshot_download`` fails (original error chained).
        FileNotFoundError: if the CSV file, the embeddings file, or the
            configured MODEL_SUBDIR directory is missing from the snapshot.
    """
    print(f"πŸ” snapshot_download: repo_id={repo_id} cache_dir={cache_dir}")
    try:
        model_dir = snapshot_download(repo_id=repo_id, repo_type=repo_type, cache_dir=cache_dir)
    except Exception as e:
        # huggingface_hub reads the token from HF_TOKEN (legacy name:
        # HUGGING_FACE_HUB_TOKEN); point users at the variables that actually work.
        raise RuntimeError(
            f"Failed to snapshot_download repo {repo_id}. "
            "Make sure HF_TOKEN (or the legacy HUGGING_FACE_HUB_TOKEN) is set for private repos and that repo exists."
        ) from e

    csv_path = os.path.join(model_dir, CSV_FILENAME)
    embed_path = os.path.join(model_dir, EMBED_FILENAME)
    model_path = os.path.join(model_dir, MODEL_SUBDIR) if MODEL_SUBDIR else model_dir

    # Collect every missing artifact so one error reports all of them.
    missing = [path for path in (csv_path, embed_path) if not os.path.isfile(path)]
    # A mistyped MODEL_SUBDIR would otherwise only surface later as an opaque
    # from_pretrained failure.
    if MODEL_SUBDIR and not os.path.isdir(model_path):
        missing.append(model_path)
    if missing:
        raise FileNotFoundError(
            f"Missing artifact(s) in the downloaded repo: {missing}\n"
            f"Push {CSV_FILENAME} and {EMBED_FILENAME} to the repo '{repo_id}' or set HF_REPO to the correct one"
            + (f" (and check MODEL_SUBDIR='{MODEL_SUBDIR}')." if MODEL_SUBDIR else ".")
        )

    print(f"βœ… Downloaded snapshot to: {model_dir}")
    print(f"βœ… Found CSV at: {csv_path}")
    print(f"βœ… Found embeddings at: {embed_path}")
    return model_path, csv_path, embed_path
def load_embeddings(emb_path: str, csv_len_expected: Optional[int] = None):
    """
    Load a precomputed embedding matrix from ``emb_path`` onto the CPU.

    Args:
        emb_path: Path to a file written with ``torch.save``.
        csv_len_expected: When given, the exact number of rows the matrix must
            have (one embedding per CSV row, in the same order).

    Returns:
        The loaded 2-D ``torch.Tensor`` of shape [N, D].

    Raises:
        ValueError: if the file does not hold a tensor, the tensor is not 2-D,
            or its row count differs from ``csv_len_expected``.
    """
    # NOTE(review): torch.load deserializes via pickle and this file comes from
    # a downloaded Hub snapshot. Depending on the installed torch version's
    # default, consider passing weights_only=True (torch>=1.13) explicitly.
    loaded = torch.load(emb_path, map_location="cpu")
    if not isinstance(loaded, torch.Tensor):
        raise ValueError(f"Embeddings file {emb_path} did not load a torch.Tensor (type={type(loaded)}).")
    if loaded.ndim != 2:
        raise ValueError(f"Embeddings tensor must have shape [N, D]. Got {tuple(loaded.shape)}.")
    n_rows = loaded.shape[0]
    if csv_len_expected is not None and n_rows != csv_len_expected:
        raise ValueError(f"Mismatch: CSV rows={csv_len_expected} but embeddings rows={n_rows}. Ensure ordering matches.")
    return loaded
def mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor):
    """
    Average token vectors over the positions selected by ``attention_mask``.

    Args:
        last_hidden_state: Token states, expected as [batch, seq, dim]
            (the mask is broadcast over the last axis and summed over axis 1).
        attention_mask: [batch, seq] mask, typically 1 for real tokens and 0
            for padding; its values are used as multiplicative weights.

    Returns:
        [batch, dim] tensor of masked means. A row whose mask is all zeros
        yields zeros, because the divisor is clamped to 1e-9 instead of 0.
    """
    weights = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
    token_total = (last_hidden_state * weights).sum(dim=1)
    token_count = weights.sum(dim=1).clamp(min=1e-9)
    return token_total / token_count
# -----------------------
# Startup: download artifacts
# -----------------------
# Fetch the snapshot once at import time; any failure is echoed to stderr and
# re-raised so the service never starts without its artifacts.
try:
    MODEL_DIR, CSV_PATH, EMBED_PATH = download_and_verify(
        HF_REPO,
        repo_type="model",
        cache_dir=CACHE_DIR,
    )
except Exception as startup_error:
    print(
        "βœ– ERROR during snapshot_download or verification:",
        startup_error,
        file=sys.stderr,
    )
    raise
# -----------------------
# Load CSV and embeddings
# -----------------------
print("πŸ“₯ Loading CSV:", CSV_PATH)
# Read every column as text; empty cells become "" rather than NaN.
df = pd.read_csv(CSV_PATH, dtype=str).fillna("")

# The retrieval endpoint reads both of these columns for every hit.
if not {"question", "answer"} <= set(df.columns):
    raise RuntimeError(f"CSV must contain 'question' and 'answer' columns. Found: {df.columns.tolist()}")

# Without a language column, every row is tagged as "en".
if "language" not in df.columns:
    df["language"] = "en"

print("πŸ“₯ Loading embeddings (this may take a second)...")
answer_embeddings = load_embeddings(EMBED_PATH, csv_len_expected=len(df))
# Cast to float32 (no-op if already float32) and L2-normalise each row, so a
# dot product with an L2-normalised query vector equals cosine similarity.
answer_embeddings = torch.nn.functional.normalize(answer_embeddings.to(torch.float32), p=2, dim=1)
print("βœ… Loaded dataset rows:", len(df), "embedding dim:", answer_embeddings.shape[1])
# -----------------------
# Load MuRIL model for query encoding (AutoModel + tokenizer)
# -----------------------
print("βš™οΈ Loading tokenizer & model for text encoding from:", MODEL_DIR)
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True)
    # nn.Module.to() and .eval() both return the module itself, so they chain:
    # move the weights to DEVICE and switch to inference mode in one expression.
    encoder_model = AutoModel.from_pretrained(MODEL_DIR).to(DEVICE).eval()
except Exception as load_error:
    # Fail fast: without an encoder the API cannot answer any query.
    raise RuntimeError(
        f"Failed to load model/tokenizer from {MODEL_DIR}. "
        "If the repo stores only adapters/LoRA, you must load base model + apply adapters. "
        "Ensure the repo contains full model files or set MODEL_SUBDIR appropriately."
    ) from load_error
def encode_query(texts: List[str], batch_size: int = 32, max_length: int = 256):
    """
    Encode texts into L2-normalised, mean-pooled embeddings.

    Args:
        texts: Strings to encode.
        batch_size: Number of texts tokenized and run through the model per
            step; must be >= 1.
        max_length: Per-text token limit; longer inputs are truncated.
            Defaults to 256.

    Returns:
        CPU tensor of shape [len(texts), hidden] whose rows have unit L2 norm.
        An empty ``texts`` list yields a [0, hidden] tensor instead of raising.

    Raises:
        ValueError: if ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        # range() with step 0 raises an opaque error, and a negative step
        # silently skips the loop; reject both up front.
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if not texts:
        # torch.cat([]) raises; return an empty matrix with the encoder's width.
        return torch.empty((0, encoder_model.config.hidden_size))

    chunks = []
    with torch.inference_mode():
        for start in range(0, len(texts), batch_size):
            batch = texts[start : start + batch_size]
            enc = tokenizer(batch, padding=True, truncation=True, max_length=max_length, return_tensors="pt")
            input_ids = enc["input_ids"].to(DEVICE)
            attention_mask = enc["attention_mask"].to(DEVICE)
            out = encoder_model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
            pooled = mean_pool(out.last_hidden_state, attention_mask)
            # Normalise so a dot product against the stored (normalised)
            # embeddings is a cosine similarity; move to CPU for concatenation.
            chunks.append(torch.nn.functional.normalize(pooled, p=2, dim=1).cpu())
    return torch.cat(chunks, dim=0)
# -----------------------
# FastAPI app
# -----------------------
# ASGI application object; the route handlers below register on it via decorators.
app = FastAPI(
    title="MuRIL Multilingual QA API (Hub-backed artifacts)",
)
class QueryRequest(BaseModel):
    """Request body for POST /get-answer."""

    # The user's question; the endpoint strips it and answers 400 if it is empty.
    question: str
    # Optional exact-match filter on the dataset's `language` column;
    # None or blank means all rows are searched.
    lang: Optional[str] = None
class QAResponse(BaseModel):
    """Response body for POST /get-answer."""

    # Best-matching stored answer, or a warning message when the language
    # filter matches no rows.
    answer: str
    # Echo of the requested language filter, or "all" when none was given;
    # despite the name, no language detection is performed.
    detected_lang: str
    # Ranked candidates (dicts with rank, score, answer, question); only set
    # when more than one candidate is ranked, otherwise None.
    top_k: Optional[List[dict]] = None
@app.get("/")
def root():
    """Return a status summary: running marker, languages in the dataset, source repo, and row count."""
    return {
        "status": "βœ… MuRIL Multilingual QA API Running",
        # Distinct values of the `language` column, alphabetically ordered.
        "available_languages": sorted(set(df["language"])),
        "model_repo": HF_REPO,
        "loaded_rows": len(df),
    }
@app.post("/get-answer", response_model=QAResponse)
def get_answer_endpoint(req: QueryRequest):
    """
    Return the stored answer whose embedding best matches the question.

    The question is encoded with ``encode_query`` and scored by dot product
    against ``answer_embeddings``; both sides are L2-normalised, so the score
    is cosine similarity. When ``req.lang`` is set, only rows whose
    ``language`` column equals it exactly are searched. When more than one
    candidate is ranked (TOP_K > 1), the ranked list is returned in ``top_k``.

    Raises:
        HTTPException(400): the question is empty or whitespace-only.
        HTTPException(503): no rows are loaded, so there is nothing to rank.
    """
    qtext = (req.question or "").strip()
    if not qtext:
        raise HTTPException(status_code=400, detail="Empty question")

    lang_filter = (req.lang or "").strip()
    filtered_df = df
    filtered_embeddings = answer_embeddings
    if lang_filter:
        mask = df["language"] == lang_filter
        if not mask.any():
            # no data for this language
            return QAResponse(answer=f"⚠️ No data found for language '{lang_filter}'.", detected_lang=lang_filter)
        filtered_df = df[mask].reset_index(drop=True)
        # Convert to a plain bool tensor so the row selection does not depend
        # on which pandas dtype backs the comparison result.
        filtered_embeddings = answer_embeddings[torch.as_tensor(mask.to_numpy(dtype=bool))]

    if len(filtered_df) == 0:
        # Only reachable when the dataset itself is empty: torch.topk cannot
        # select from zero candidates and would otherwise surface as a 500.
        raise HTTPException(status_code=503, detail="No QA data loaded")

    q_emb = encode_query([qtext], batch_size=1)  # shape (1, D)
    sims = torch.matmul(q_emb, filtered_embeddings.T).squeeze(0)  # (N,)
    # TOP_K <= 0 still yields the single best answer; never ask for more rows than exist.
    k = max(1, min(TOP_K, len(filtered_df)))
    topv, topi = torch.topk(sims, k=k)
    scores = topv.tolist()
    indices = topi.tolist()
    answer = filtered_df.iloc[indices[0]]["answer"]

    top_k_payload = None
    if k > 1:
        top_k_payload = [
            {
                "rank": rank,
                "score": score,
                "answer": filtered_df.iloc[idx]["answer"],
                "question": filtered_df.iloc[idx]["question"],
            }
            for rank, (score, idx) in enumerate(zip(scores, indices), start=1)
        ]
    return QAResponse(answer=answer, detected_lang=lang_filter or "all", top_k=top_k_payload)
# -----------------------
# Run server (if invoked directly)
# -----------------------
if __name__ == "__main__":
    import uvicorn

    # Pass the already-built app object instead of the "main:app" import string.
    # The string form makes uvicorn import this file again as module "main",
    # which re-runs all import-time startup work (snapshot_download, CSV and
    # embedding loading, model loading), keeping two model copies in memory,
    # and it breaks if the file is renamed. An app object is valid here
    # because reload is off and workers == 1.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")), workers=1)