# main.py
import os
import sys
import torch
import pandas as pd
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, AutoModel

# -----------------------
# Configuration (env)
# -----------------------
# Repo containing artifacts (CSV, embeddings, optionally model weights)
HF_REPO = os.getenv(
    "HF_REPO",
    "Sp2503/Finetuned-multilingualdataset-MuriL-model"
)

# Where snapshot_download will cache the downloaded files
CACHE_DIR = os.getenv("HF_CACHE_DIR", "/tmp/hf_artifacts")

# Names expected inside the repo
CSV_FILENAME = os.getenv("CSV_FILENAME", "muril_multilingual_dataset.csv")
EMBED_FILENAME = os.getenv("EMBED_FILENAME", "answer_embeddings.pt")
# If the fine-tuned model lives in a subdirectory of the same repo, set this
# to that subdirectory (optional); leave empty if the model is in the repo root.
MODEL_SUBDIR = os.getenv("MODEL_SUBDIR", "")

# Retrieval config
TOP_K = int(os.getenv("TOP_K", "1"))
DEVICE = "cuda" if torch.cuda.is_available() and os.getenv("TORCH_DISABLE_CUDA", "0") != "1" else "cpu"
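
# Example environment setup (hypothetical values; each variable is read above):
#   export HF_REPO="Sp2503/Finetuned-multilingualdataset-MuriL-model"
#   export HF_CACHE_DIR="/tmp/hf_artifacts"
#   export TOP_K="3"                # return the 3 best matches per query
#   export TORCH_DISABLE_CUDA="1"   # force CPU even when a GPU is present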

# -----------------------
# Utility helpers
# -----------------------
def download_and_verify(repo_id: str, repo_type: str = "model", cache_dir: str = CACHE_DIR):
    """
    Download repo snapshot and verify that CSV + embeddings exist.
    Returns absolute paths: model_dir, csv_path, embed_path
    """
    print(f"πŸ” snapshot_download: repo_id={repo_id} cache_dir={cache_dir}")
    try:
        model_dir = snapshot_download(repo_id=repo_id, repo_type=repo_type, cache_dir=cache_dir)
    except Exception as e:
        raise RuntimeError(
            f"Failed to snapshot_download repo {repo_id}. "
            "Make sure HUGGINGFACE_HUB_TOKEN (or HF_TOKEN) is set for private repos and that repo exists."
        ) from e

    csv_path = os.path.join(model_dir, CSV_FILENAME)
    embed_path = os.path.join(model_dir, EMBED_FILENAME)
    model_path = os.path.join(model_dir, MODEL_SUBDIR) if MODEL_SUBDIR else model_dir

    # checks
    missing = []
    if not os.path.isfile(csv_path):
        missing.append(csv_path)
    if not os.path.isfile(embed_path):
        missing.append(embed_path)

    if missing:
        raise FileNotFoundError(
            f"Missing artifact(s) in the downloaded repo: {missing}\n"
            f"Push {CSV_FILENAME} and {EMBED_FILENAME} to the repo '{repo_id}' or set HF_REPO to the correct one."
        )

    print(f"βœ… Downloaded snapshot to: {model_dir}")
    print(f"βœ… Found CSV at: {csv_path}")
    print(f"βœ… Found embeddings at: {embed_path}")
    return model_path, csv_path, embed_path

def load_embeddings(emb_path: str, csv_len_expected: Optional[int] = None):
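    """Load the answer-embedding tensor of shape [N, D] and optionally verify N matches the CSV row count."""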
    emb = torch.load(emb_path, map_location="cpu")
    if not isinstance(emb, torch.Tensor):
        raise ValueError(f"Embeddings file {emb_path} did not load a torch.Tensor (type={type(emb)}).")
    if emb.ndim != 2:
        raise ValueError(f"Embeddings tensor must have shape [N, D]. Got {tuple(emb.shape)}.")
    if csv_len_expected is not None and emb.shape[0] != csv_len_expected:
        raise ValueError(f"Mismatch: CSV rows={csv_len_expected} but embeddings rows={emb.shape[0]}. Ensure ordering matches.")
    return emb

def mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor):
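    """Mean-pool token embeddings over the sequence dimension, ignoring padded positions via the attention mask."""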
    mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
    summed = torch.sum(last_hidden_state * mask, dim=1)
    denom = torch.clamp(mask.sum(dim=1), min=1e-9)
    return summed / denom

# -----------------------
# Startup: download artifacts
# -----------------------
try:
    MODEL_DIR, CSV_PATH, EMBED_PATH = download_and_verify(HF_REPO, repo_type="model", cache_dir=CACHE_DIR)
except Exception as e:
    print("βœ– ERROR during snapshot_download or verification:", e, file=sys.stderr)
    raise

# -----------------------
# Load CSV and embeddings
# -----------------------
print("πŸ“₯ Loading CSV:", CSV_PATH)
df = pd.read_csv(CSV_PATH, dtype=str).fillna("")
# Ensure required columns
if not {"question", "answer"}.issubset(set(df.columns)):
    raise RuntimeError(f"CSV must contain 'question' and 'answer' columns. Found: {df.columns.tolist()}")

# Default language column if missing
if "language" not in df.columns:
    df["language"] = "en"

print("πŸ“₯ Loading embeddings (this may take a second)...")
answer_embeddings = load_embeddings(EMBED_PATH, csv_len_expected=len(df))
# Queries are L2-normalized at encode time; cast the stored embeddings to float32
# and normalize them as well, so that dot product equals cosine similarity.
if answer_embeddings.dtype != torch.float32:
    answer_embeddings = answer_embeddings.to(torch.float32)
answer_embeddings = torch.nn.functional.normalize(answer_embeddings, p=2, dim=1)

print("βœ… Loaded dataset rows:", len(df), "embedding dim:", answer_embeddings.shape[1])

# -----------------------
# Load MuRIL model for query encoding (AutoModel + tokenizer)
# -----------------------
print("βš™οΈ Loading tokenizer & model for text encoding from:", MODEL_DIR)
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True)
    encoder_model = AutoModel.from_pretrained(MODEL_DIR)
    encoder_model.to(DEVICE)
    encoder_model.eval()
except Exception as e:
    # fail fast with helpful message
    raise RuntimeError(
        f"Failed to load model/tokenizer from {MODEL_DIR}. "
        "If the repo stores only adapters/LoRA, you must load base model + apply adapters. "
        "Ensure the repo contains full model files or set MODEL_SUBDIR appropriately."
    ) from e

def encode_query(texts: List[str], batch_size: int = 32):
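    """Encode texts in batches into L2-normalized, mean-pooled embeddings (returned on CPU)."""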
    all_embs = []
    with torch.inference_mode():
        for i in range(0, len(texts), batch_size):
            batch = texts[i : i + batch_size]
            enc = tokenizer(batch, padding=True, truncation=True, max_length=256, return_tensors="pt")
            input_ids = enc["input_ids"].to(DEVICE)
            attention_mask = enc["attention_mask"].to(DEVICE)
            out = encoder_model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
            pooled = mean_pool(out.last_hidden_state, attention_mask)
            pooled = torch.nn.functional.normalize(pooled, p=2, dim=1)
            all_embs.append(pooled.cpu())
    return torch.cat(all_embs, dim=0)
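
# Quick sanity check for the encoder (illustrative; the sample question is hypothetical):
#   q = encode_query(["What documents are needed to open an account?"])
#   assert q.shape == (1, answer_embeddings.shape[1])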

# -----------------------
# FastAPI app
# -----------------------
app = FastAPI(title="MuRIL Multilingual QA API (Hub-backed artifacts)")

class QueryRequest(BaseModel):
    question: str
    lang: Optional[str] = None

class QAResponse(BaseModel):
    answer: str
    detected_lang: str
    top_k: Optional[List[dict]] = None

@app.get("/")
def root():
    langs = sorted(df["language"].unique().tolist())
    return {
        "status": "βœ… MuRIL Multilingual QA API Running",
        "available_languages": langs,
        "model_repo": HF_REPO,
        "loaded_rows": len(df),
    }

@app.post("/get-answer", response_model=QAResponse)
def get_answer_endpoint(req: QueryRequest):
    qtext = (req.question or "").strip()
    if not qtext:
        raise HTTPException(status_code=400, detail="Empty question")

    lang_filter = (req.lang or "").strip()
    filtered_df = df
    filtered_embeddings = answer_embeddings
    if lang_filter:
        mask = df["language"] == lang_filter
        if not mask.any():
            # no data for this language
            return QAResponse(answer=f"⚠️ No data found for language '{lang_filter}'.", detected_lang=lang_filter)
        filtered_df = df[mask].reset_index(drop=True)
        # convert the pandas boolean mask to a tensor for explicit boolean indexing
        filtered_embeddings = answer_embeddings[torch.from_numpy(mask.values)]

    # encode query
    q_emb = encode_query([qtext], batch_size=1)  # shape (1, D)
    sims = torch.matmul(q_emb, filtered_embeddings.T).squeeze(0)  # (N,)
    k = max(1, min(TOP_K, len(filtered_df)))
    topv, topi = torch.topk(sims, k=k)
    best_idx = int(topi[0].item())
    answer = filtered_df.iloc[best_idx]["answer"]

    top_k_payload = None
    if k > 1:
        top_k_payload = []
        for rank in range(k):
            idx = int(topi[rank].item())
            top_k_payload.append({
                "rank": rank + 1,
                "score": float(topv[rank].item()),
                "answer": filtered_df.iloc[idx]["answer"],
                "question": filtered_df.iloc[idx]["question"],
            })

    return QAResponse(answer=answer, detected_lang=lang_filter or "all", top_k=top_k_payload)
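
# Example request against the endpoint above (hypothetical payload; default port 7860):
#   curl -X POST http://localhost:7860/get-answer \
#        -H "Content-Type: application/json" \
#        -d '{"question": "How do I reset my password?", "lang": "en"}'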

# -----------------------
# Run server (if invoked directly)
# -----------------------
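# Equivalent to launching via the uvicorn CLI:
#   uvicorn main:app --host 0.0.0.0 --port 7860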
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("main:app", host="0.0.0.0", port=int(os.getenv("PORT", "7860")), workers=1)