# main.py
import os
import sys
import torch
import pandas as pd
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, AutoModel
import numpy as np
# -----------------------
# Configuration (env)
# -----------------------
# Repo containing artifacts (CSV, embeddings, optionally model weights)
HF_REPO = os.getenv(
    "HF_REPO",
    "Sp2503/Finetuned-multilingualdataset-MuriL-model",
)
# Where snapshot_download will cache/put the files
CACHE_DIR = os.getenv("HF_CACHE_DIR", "/tmp/hf_artifacts")
# Names expected inside the repo
CSV_FILENAME = os.getenv("CSV_FILENAME", "muril_multilingual_dataset.csv")
EMBED_FILENAME = os.getenv("EMBED_FILENAME", "answer_embeddings.pt")
# If you stored a fine-tuned model in the same repo, set this to repo path (optional)
MODEL_SUBDIR = os.getenv("MODEL_SUBDIR", "") # leave empty if model is in root
# retrieval config
TOP_K = int(os.getenv("TOP_K", "1"))
DEVICE = "cuda" if torch.cuda.is_available() and os.getenv("TORCH_DISABLE_CUDA", "0") != "1" else "cpu"
# -----------------------
# Utility helpers
# -----------------------
def download_and_verify(repo_id: str, repo_type: str = "model", cache_dir: str = CACHE_DIR):
    """
    Download a repo snapshot and verify that the CSV + embeddings exist.
    Returns absolute paths: (model_path, csv_path, embed_path).
    """
    print(f"⬇️ snapshot_download: repo_id={repo_id} cache_dir={cache_dir}")
    try:
        model_dir = snapshot_download(repo_id=repo_id, repo_type=repo_type, cache_dir=cache_dir)
    except Exception as e:
        raise RuntimeError(
            f"Failed to snapshot_download repo {repo_id}. "
            "Make sure HUGGINGFACE_HUB_TOKEN (or HF_TOKEN) is set for private repos and that the repo exists."
        ) from e
    csv_path = os.path.join(model_dir, CSV_FILENAME)
    embed_path = os.path.join(model_dir, EMBED_FILENAME)
    model_path = os.path.join(model_dir, MODEL_SUBDIR) if MODEL_SUBDIR else model_dir
    # checks
    missing = []
    if not os.path.isfile(csv_path):
        missing.append(csv_path)
    if not os.path.isfile(embed_path):
        missing.append(embed_path)
    if missing:
        raise FileNotFoundError(
            f"Missing artifact(s) in the downloaded repo: {missing}\n"
            f"Push {CSV_FILENAME} and {EMBED_FILENAME} to the repo '{repo_id}' or set HF_REPO to the correct one."
        )
    print(f"✅ Downloaded snapshot to: {model_dir}")
    print(f"✅ Found CSV at: {csv_path}")
    print(f"✅ Found embeddings at: {embed_path}")
    return model_path, csv_path, embed_path
def load_embeddings(emb_path: str, csv_len_expected: Optional[int] = None):
    emb = torch.load(emb_path, map_location="cpu")
    if not isinstance(emb, torch.Tensor):
        raise ValueError(f"Embeddings file {emb_path} did not load a torch.Tensor (type={type(emb)}).")
    if emb.ndim != 2:
        raise ValueError(f"Embeddings tensor must have shape [N, D]. Got {tuple(emb.shape)}.")
    if csv_len_expected is not None and emb.shape[0] != csv_len_expected:
        raise ValueError(f"Mismatch: CSV rows={csv_len_expected} but embeddings rows={emb.shape[0]}. Ensure ordering matches.")
    return emb
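# The embeddings file is expected to be a plain 2-D tensor saved with
# torch.save, one row per CSV row and in the same row order. A sketch of how
# such a file could be produced (the encoder call is illustrative):
#   torch.save(encode_answers(df["answer"].tolist()), EMBED_FILENAME)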
def mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor):
    mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
    summed = torch.sum(last_hidden_state * mask, dim=1)
    denom = torch.clamp(mask.sum(dim=1), min=1e-9)
    return summed / denom
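# mean_pool averages token vectors over positions where attention_mask == 1:
#   pooled = sum_t(h_t * m_t) / max(sum_t(m_t), 1e-9)
# so padding tokens never dilute the sentence embedding.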
# -----------------------
# Startup: download artifacts
# -----------------------
try:
    MODEL_DIR, CSV_PATH, EMBED_PATH = download_and_verify(HF_REPO, repo_type="model", cache_dir=CACHE_DIR)
except Exception as e:
    print("❌ ERROR during snapshot_download or verification:", e, file=sys.stderr)
    raise
# -----------------------
# Load CSV and embeddings
# -----------------------
print("π₯ Loading CSV:", CSV_PATH)
df = pd.read_csv(CSV_PATH, dtype=str).fillna("")
# Ensure required columns
if not {"question", "answer"}.issubset(set(df.columns)):
raise RuntimeError(f"CSV must contain 'question' and 'answer' columns. Found: {df.columns.tolist()}")
# Default language column if missing
if "language" not in df.columns:
df["language"] = "en"
print("π₯ Loading embeddings (this may take a second)...")
answer_embeddings = load_embeddings(EMBED_PATH, csv_len_expected=len(df))
# Ensure embeddings are float32 and normalized (we normalize incoming queries; if embeddings not normalized, we normalize here)
if answer_embeddings.dtype != torch.float32:
answer_embeddings = answer_embeddings.to(torch.float32)
# Normalize stored embeddings for dot-product = cosine
answer_embeddings = torch.nn.functional.normalize(answer_embeddings, p=2, dim=1)
print("β
Loaded dataset rows:", len(df), "embedding dim:", answer_embeddings.shape[1])
# -----------------------
# Load MuRIL model for query encoding (AutoModel + tokenizer)
# -----------------------
print("βοΈ Loading tokenizer & model for text encoding from:", MODEL_DIR)
try:
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True)
encoder_model = AutoModel.from_pretrained(MODEL_DIR)
encoder_model.to(DEVICE)
encoder_model.eval()
except Exception as e:
# fail fast with helpful message
raise RuntimeError(
f"Failed to load model/tokenizer from {MODEL_DIR}. "
"If the repo stores only adapters/LoRA, you must load base model + apply adapters. "
"Ensure the repo contains full model files or set MODEL_SUBDIR appropriately."
) from e
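# If the repo only ships LoRA adapters rather than full weights, a minimal
# sketch of applying them (assumes the `peft` package is installed and that
# the base checkpoint is google/muril-base-cased; adjust to your setup):
#   from peft import PeftModel
#   base = AutoModel.from_pretrained("google/muril-base-cased")
#   encoder_model = PeftModel.from_pretrained(base, MODEL_DIR)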
def encode_query(texts: List[str], batch_size: int = 32):
    all_embs = []
    with torch.inference_mode():
        for i in range(0, len(texts), batch_size):
            batch = texts[i : i + batch_size]
            enc = tokenizer(batch, padding=True, truncation=True, max_length=256, return_tensors="pt")
            input_ids = enc["input_ids"].to(DEVICE)
            attention_mask = enc["attention_mask"].to(DEVICE)
            out = encoder_model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
            pooled = mean_pool(out.last_hidden_state, attention_mask)
            pooled = torch.nn.functional.normalize(pooled, p=2, dim=1)
            all_embs.append(pooled.cpu())
    return torch.cat(all_embs, dim=0)
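# Example (illustrative): encode_query(["What is MuRIL?"]) returns a unit-norm
# tensor of shape (1, hidden_size); hidden_size is 768 for the base MuRIL encoder.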
# -----------------------
# FastAPI app
# -----------------------
app = FastAPI(title="MuRIL Multilingual QA API (Hub-backed artifacts)")
class QueryRequest(BaseModel):
    question: str
    lang: Optional[str] = None

class QAResponse(BaseModel):
    answer: str
    detected_lang: str
    top_k: Optional[List[dict]] = None
@app.get("/")
def root():
langs = sorted(df["language"].unique().tolist())
return {
"status": "β
MuRIL Multilingual QA API Running",
"available_languages": langs,
"model_repo": HF_REPO,
"loaded_rows": len(df),
}
@app.post("/get-answer", response_model=QAResponse)
def get_answer_endpoint(req: QueryRequest):
qtext = (req.question or "").strip()
if not qtext:
raise HTTPException(status_code=400, detail="Empty question")
lang_filter = (req.lang or "").strip()
filtered_df = df
filtered_embeddings = answer_embeddings
if lang_filter:
mask = df["language"] == lang_filter
if not mask.any():
# no data for this language
return QAResponse(answer=f"β οΈ No data found for language '{lang_filter}'.", detected_lang=lang_filter)
filtered_df = df[mask].reset_index(drop=True)
filtered_embeddings = answer_embeddings[mask.values]
# encode query
q_emb = encode_query([qtext], batch_size=1) # shape (1, D)
sims = torch.matmul(q_emb, filtered_embeddings.T).squeeze(0) # (N,)
k = max(1, min(TOP_K, len(filtered_df)))
topv, topi = torch.topk(sims, k=k)
best_idx = int(topi[0].item())
answer = filtered_df.iloc[best_idx]["answer"]
top_k_payload = None
if k > 1:
top_k_payload = []
for rank in range(k):
idx = int(topi[rank].item())
top_k_payload.append({
"rank": rank + 1,
"score": float(topv[rank].item()),
"answer": filtered_df.iloc[idx]["answer"],
"question": filtered_df.iloc[idx]["question"],
})
return QAResponse(answer=answer, detected_lang=lang_filter or "all", top_k=top_k_payload)
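# Example request (illustrative; default port 7860, see PORT below):
#   curl -X POST http://localhost:7860/get-answer \
#        -H "Content-Type: application/json" \
#        -d '{"question": "What is MuRIL?", "lang": "en"}'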
# -----------------------
# Run server (if invoked directly)
# -----------------------
if __name__ == "__main__":
    import uvicorn

    uvicorn.run("main:app", host="0.0.0.0", port=int(os.getenv("PORT", "7860")), workers=1)
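# Equivalent uvicorn CLI invocation:
#   uvicorn main:app --host 0.0.0.0 --port 7860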