# main.py
"""MuRIL multilingual QA API backed by Hugging Face Hub artifacts.

At import time this module downloads a Hub snapshot (CSV of question/answer
rows, a tensor of precomputed answer embeddings, and the encoder model),
loads them into memory, and builds the FastAPI ``app``. Questions are answered
by encoding the query with the same encoder, mean-pooling, L2-normalizing, and
picking the stored answer with the highest dot product (cosine similarity).
"""
import os
import sys
from typing import List, Optional

import numpy as np
import pandas as pd
import torch
from fastapi import FastAPI, HTTPException
from huggingface_hub import snapshot_download
from pydantic import BaseModel
from transformers import AutoModel, AutoTokenizer

# -----------------------
# Configuration (env)
# -----------------------
# Repo containing artifacts (CSV, embeddings, optionally model weights)
HF_REPO = os.getenv(
    "HF_REPO", "Sp2503/Finetuned-multilingualdataset-MuriL-model"
)
# Where snapshot_download will cache/put the files
CACHE_DIR = os.getenv("HF_CACHE_DIR", "/tmp/hf_artifacts")
# Names expected inside the repo
CSV_FILENAME = os.getenv("CSV_FILENAME", "muril_multilingual_dataset.csv")
EMBED_FILENAME = os.getenv("EMBED_FILENAME", "answer_embeddings.pt")
# If you stored a fine-tuned model in a subdirectory of the same repo, set this
# to that relative path; leave empty if the model files are in the repo root.
MODEL_SUBDIR = os.getenv("MODEL_SUBDIR", "")
# Retrieval config
TOP_K = int(os.getenv("TOP_K", "1"))
# Maximum tokenized query length passed to the tokenizer (truncation limit).
MAX_SEQ_LENGTH = int(os.getenv("MAX_SEQ_LENGTH", "256"))
DEVICE = "cuda" if torch.cuda.is_available() and os.getenv("TORCH_DISABLE_CUDA", "0") != "1" else "cpu"


# -----------------------
# Utility helpers
# -----------------------
def download_and_verify(repo_id: str, repo_type: str = "model", cache_dir: str = CACHE_DIR):
    """Download a repo snapshot and verify the expected artifacts exist in it.

    Args:
        repo_id: Hugging Face Hub repository id.
        repo_type: Repository type forwarded to ``snapshot_download``.
        cache_dir: Local cache directory forwarded to ``snapshot_download``.

    Returns:
        Tuple ``(model_path, csv_path, embed_path)``. ``model_path`` is the
        snapshot root, or ``<snapshot>/<MODEL_SUBDIR>`` when MODEL_SUBDIR is set.

    Raises:
        RuntimeError: If the snapshot download fails.
        FileNotFoundError: If the CSV, the embeddings file, or the configured
            model subdirectory is missing from the snapshot.
    """
    print(f"🔁 snapshot_download: repo_id={repo_id} cache_dir={cache_dir}")
    try:
        snapshot_dir = snapshot_download(repo_id=repo_id, repo_type=repo_type, cache_dir=cache_dir)
    except Exception as e:
        raise RuntimeError(
            f"Failed to snapshot_download repo {repo_id}. "
            "Make sure HUGGINGFACE_HUB_TOKEN (or HF_TOKEN) is set for private repos and that repo exists."
        ) from e

    csv_path = os.path.join(snapshot_dir, CSV_FILENAME)
    embed_path = os.path.join(snapshot_dir, EMBED_FILENAME)
    model_path = os.path.join(snapshot_dir, MODEL_SUBDIR) if MODEL_SUBDIR else snapshot_dir

    missing = [path for path in (csv_path, embed_path) if not os.path.isfile(path)]
    # A wrong MODEL_SUBDIR would otherwise only surface later as an opaque
    # from_pretrained() failure, so check it here with the other artifacts.
    if MODEL_SUBDIR and not os.path.isdir(model_path):
        missing.append(model_path)
    if missing:
        subdir_hint = f" Also check MODEL_SUBDIR='{MODEL_SUBDIR}'." if MODEL_SUBDIR else ""
        raise FileNotFoundError(
            f"Missing artifact(s) in the downloaded repo: {missing}\n"
            f"Push {CSV_FILENAME} and {EMBED_FILENAME} to the repo '{repo_id}' or set HF_REPO to the correct one."
            f"{subdir_hint}"
        )

    print(f"✅ Downloaded snapshot to: {snapshot_dir}")
    print(f"✅ Found CSV at: {csv_path}")
    print(f"✅ Found embeddings at: {embed_path}")
    return model_path, csv_path, embed_path


def load_embeddings(emb_path: str, csv_len_expected: Optional[int] = None) -> torch.Tensor:
    """Load a 2-D embeddings tensor from ``emb_path`` onto the CPU.

    Args:
        emb_path: Path to a file written with ``torch.save``.
        csv_len_expected: If given, the required number of rows (one per CSV row).

    Returns:
        The loaded tensor of shape ``[N, D]``.

    Raises:
        ValueError: If the file does not hold a tensor, the tensor is not 2-D,
            or its row count differs from ``csv_len_expected``.
    """
    # The file comes from a remote repo and torch.load unpickles it; restrict
    # loading to tensor data where supported (weights_only needs torch >= 1.13).
    try:
        emb = torch.load(emb_path, map_location="cpu", weights_only=True)
    except TypeError:
        emb = torch.load(emb_path, map_location="cpu")
    if not isinstance(emb, torch.Tensor):
        raise ValueError(f"Embeddings file {emb_path} did not load a torch.Tensor (type={type(emb)}).")
    if emb.ndim != 2:
        raise ValueError(f"Embeddings tensor must have shape [N, D]. Got {tuple(emb.shape)}.")
    if csv_len_expected is not None and emb.shape[0] != csv_len_expected:
        raise ValueError(
            f"Mismatch: CSV rows={csv_len_expected} but embeddings rows={emb.shape[0]}. Ensure ordering matches."
        )
    return emb


def mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average token vectors over the sequence axis, ignoring padded positions.

    Expects ``last_hidden_state`` of shape ``(batch, seq, hidden)`` and
    ``attention_mask`` of shape ``(batch, seq)``; returns ``(batch, hidden)``.
    """
    mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
    summed = torch.sum(last_hidden_state * mask, dim=1)
    # Clamp so an all-padding row divides by a tiny number instead of zero.
    denom = torch.clamp(mask.sum(dim=1), min=1e-9)
    return summed / denom


# -----------------------
# Startup: download artifacts
# -----------------------
try:
    MODEL_DIR, CSV_PATH, EMBED_PATH = download_and_verify(HF_REPO, repo_type="model", cache_dir=CACHE_DIR)
except Exception as e:
    print("✖ ERROR during snapshot_download or verification:", e, file=sys.stderr)
    raise

# -----------------------
# Load CSV and embeddings
# -----------------------
print("📥 Loading CSV:", CSV_PATH)
df = pd.read_csv(CSV_PATH, dtype=str).fillna("")

# Ensure required columns
if not {"question", "answer"}.issubset(set(df.columns)):
    raise RuntimeError(f"CSV must contain 'question' and 'answer' columns. Found: {df.columns.tolist()}")

# Default language column if missing
if "language" not in df.columns:
    df["language"] = "en"

print("📥 Loading embeddings (this may take a second)...")
answer_embeddings = load_embeddings(EMBED_PATH, csv_len_expected=len(df))

# Queries are L2-normalized in encode_query; normalize the stored embeddings too
# (in float32) so that a plain dot product equals cosine similarity.
if answer_embeddings.dtype != torch.float32:
    answer_embeddings = answer_embeddings.to(torch.float32)
answer_embeddings = torch.nn.functional.normalize(answer_embeddings, p=2, dim=1)

print("✅ Loaded dataset rows:", len(df), "embedding dim:", answer_embeddings.shape[1])

# -----------------------
# Load MuRIL model for query encoding (AutoModel + tokenizer)
# -----------------------
print("⚙️ Loading tokenizer & model for text encoding from:", MODEL_DIR)
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True)
    encoder_model = AutoModel.from_pretrained(MODEL_DIR)
    encoder_model.to(DEVICE)
    encoder_model.eval()
except Exception as e:
    # fail fast with helpful message
    raise RuntimeError(
        f"Failed to load model/tokenizer from {MODEL_DIR}. "
        "If the repo stores only adapters/LoRA, you must load base model + apply adapters. "
        "Ensure the repo contains full model files or set MODEL_SUBDIR appropriately."
    ) from e

# Query vectors come from mean-pooled hidden states, so their width is the
# encoder's hidden size. A mismatch with the stored embeddings would make every
# request fail in torch.matmul; detect it once at startup instead.
_encoder_dim = getattr(encoder_model.config, "hidden_size", None)
if _encoder_dim is not None and _encoder_dim != answer_embeddings.shape[1]:
    raise RuntimeError(
        f"Encoder hidden size ({_encoder_dim}) does not match stored embedding dim "
        f"({answer_embeddings.shape[1]}). Were the embeddings built with a different model?"
    )


def encode_query(texts: List[str], batch_size: int = 32) -> torch.Tensor:
    """Encode ``texts`` into L2-normalized mean-pooled vectors on the CPU.

    Args:
        texts: Strings to encode.
        batch_size: Number of strings tokenized and run through the model at once.

    Returns:
        Float tensor of shape ``(len(texts), D)``.
    """
    if not texts:
        # torch.cat([]) would raise; return a correctly-shaped empty result.
        return torch.empty((0, answer_embeddings.shape[1]), dtype=torch.float32)
    all_embs = []
    with torch.inference_mode():
        for start in range(0, len(texts), batch_size):
            batch = texts[start : start + batch_size]
            enc = tokenizer(batch, padding=True, truncation=True, max_length=MAX_SEQ_LENGTH, return_tensors="pt")
            input_ids = enc["input_ids"].to(DEVICE)
            attention_mask = enc["attention_mask"].to(DEVICE)
            out = encoder_model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
            pooled = mean_pool(out.last_hidden_state, attention_mask)
            pooled = torch.nn.functional.normalize(pooled, p=2, dim=1)
            all_embs.append(pooled.cpu())
    return torch.cat(all_embs, dim=0)


# -----------------------
# FastAPI app
# -----------------------
app = FastAPI(title="MuRIL Multilingual QA API (Hub-backed artifacts)")


class QueryRequest(BaseModel):
    question: str
    # Optional exact-match filter on the CSV "language" column.
    lang: Optional[str] = None


class QAResponse(BaseModel):
    answer: str
    # Echoes the requested language filter ("all" when none); no language
    # detection is performed by this service.
    detected_lang: str
    # Ranked candidates; only populated when TOP_K > 1.
    top_k: Optional[List[dict]] = None


@app.get("/")
def root():
    """Report service status, the languages present in the dataset, and its size."""
    langs = sorted(df["language"].unique().tolist())
    return {
        "status": "✅ MuRIL Multilingual QA API Running",
        "available_languages": langs,
        "model_repo": HF_REPO,
        "loaded_rows": len(df),
    }


@app.post("/get-answer", response_model=QAResponse)
def get_answer_endpoint(req: QueryRequest):
    """Return the stored answer most similar to the question.

    Raises:
        HTTPException: 400 if the question is empty or whitespace-only.
    """
    qtext = (req.question or "").strip()
    if not qtext:
        raise HTTPException(status_code=400, detail="Empty question")

    lang_filter = (req.lang or "").strip()
    filtered_df = df
    filtered_embeddings = answer_embeddings
    if lang_filter:
        mask = (df["language"] == lang_filter).to_numpy()
        if not mask.any():
            # no data for this language
            return QAResponse(answer=f"⚠️ No data found for language '{lang_filter}'.", detected_lang=lang_filter)
        filtered_df = df[mask].reset_index(drop=True)
        # Index the tensor with an explicit torch bool mask (rows stay aligned with filtered_df).
        filtered_embeddings = answer_embeddings[torch.from_numpy(mask)]

    if len(filtered_df) == 0:
        # Only reachable when the CSV has no rows; torch.topk fails on an empty tensor.
        return QAResponse(answer="⚠️ No data loaded.", detected_lang=lang_filter or "all")

    # encode query
    q_emb = encode_query([qtext], batch_size=1)  # shape (1, D)
    sims = torch.matmul(q_emb, filtered_embeddings.T).squeeze(0)  # (N,)

    k = max(1, min(TOP_K, len(filtered_df)))
    topv, topi = torch.topk(sims, k=k)
    scores = topv.tolist()
    indices = topi.tolist()
    answer = filtered_df.iloc[indices[0]]["answer"]

    top_k_payload = None
    if k > 1:
        top_k_payload = [
            {
                "rank": rank,
                "score": float(score),
                "answer": filtered_df.iloc[idx]["answer"],
                "question": filtered_df.iloc[idx]["question"],
            }
            for rank, (score, idx) in enumerate(zip(scores, indices), start=1)
        ]

    return QAResponse(answer=answer, detected_lang=lang_filter or "all", top_k=top_k_payload)


# -----------------------
# Run server (if invoked directly)
# -----------------------
if __name__ == "__main__":
    import uvicorn

    # Pass the app object rather than the "main:app" import string: the string
    # makes uvicorn import this file again as module "main", re-running the
    # snapshot download and model load (and failing if the file is renamed).
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")), workers=1)