"""
Regenerate answer embeddings using the MuRIL model.

This script:
- downloads the model (if MODEL_DIR is a Hugging Face repo id),
- reads the CSV at CSV_PATH,
- computes mean-pooled, L2-normalized embeddings for the 'answer' column,
- saves the embeddings to OUT_EMBED_PATH.

Exit codes:
- 0 on success
- non-zero on failure
"""

import os
import sys
from pathlib import Path

import pandas as pd
import torch
from huggingface_hub import snapshot_download
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModel
|
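# Example invocation (the filename regenerate_embeddings.py is illustrative;
# all env vars are optional and fall back to the defaults in parse_env below):
#   CSV_PATH=/app/export_artifacts/muril_multilingual_dataset.csv \
#   OUT_EMBED_PATH=/app/export_artifacts/answer_embeddings.pt \
#   python regenerate_embeddings.py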

def mean_pooling(last_hidden_state, attention_mask):
    """Average the token embeddings of each sequence, ignoring padding."""
    # Broadcast the (batch, seq_len) mask to (batch, seq_len, hidden) so that
    # padded positions contribute zeros to the sum.
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
    sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, 1)
    # Clamp avoids division by zero for sequences whose mask is all zeros.
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return sum_embeddings / sum_mask
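
# Example (shapes only): mean_pooling(torch.randn(2, 4, 768), torch.ones(2, 4))
# returns a (2, 768) tensor, i.e. one pooled vector per sequence in the batch.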


def parse_env():
    """Build the run configuration from environment variables, with defaults."""
    cfg = {}
    cfg['model_dir'] = os.getenv("MODEL_DIR", os.getenv("HF_REPO", "Sp2503/Finetuned-multilingualdataset-MuriL-model"))
    cfg['csv_path'] = os.getenv("CSV_PATH", "/app/export_artifacts/muril_multilingual_dataset.csv")
    cfg['out_path'] = os.getenv("OUT_EMBED_PATH", "/app/export_artifacts/answer_embeddings.pt")
    cfg['batch_size'] = int(os.getenv("EMBED_BATCH_SIZE", "64"))
    cfg['device'] = os.getenv("DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
    cfg['download_cache'] = os.getenv("HF_CACHE_DIR", "/tmp/hf_cache")
    cfg['upload_back'] = os.getenv("UPLOAD_BACK", "false").lower() in ("1", "true", "yes")
    cfg['hf_repo'] = os.getenv("HF_REPO", None)
    return cfg
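
# To run on CPU with smaller batches (same illustrative filename as above):
#   EMBED_BATCH_SIZE=16 DEVICE=cpu python regenerate_embeddings.py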


def main():
    cfg = parse_env()
    print("Regenerate embeddings with config:", cfg)
    model_dir = cfg['model_dir']

    # Treat MODEL_DIR as a Hugging Face repo id when it contains "/" but is not
    # an existing local directory; otherwise use it as a local path directly.
    if "/" in model_dir and not os.path.isdir(model_dir):
        print("Detected HF repo id for model. snapshot_download ->", cfg['download_cache'])
        try:
            model_dir = snapshot_download(repo_id=cfg['model_dir'], repo_type="model", cache_dir=cfg['download_cache'])
            print("Downloaded model to:", model_dir)
        except Exception as e:
            print("Failed to snapshot_download model:", e, file=sys.stderr)
            sys.exit(2)

    csv_path = cfg['csv_path']
    out_path = cfg['out_path']
    batch_size = cfg['batch_size']
    device = cfg['device']

    print(f"Loading CSV: {csv_path}")
    if not os.path.isfile(csv_path):
        print(f"CSV not found at {csv_path}", file=sys.stderr)
        sys.exit(3)
    df = pd.read_csv(csv_path, dtype=str).fillna("")
    if 'answer' not in df.columns:
        print("CSV must contain 'answer' column", file=sys.stderr)
        sys.exit(4)
    answers = df['answer'].astype(str).tolist()
    print(f"Encoding {len(answers)} answers on device {device} (batch_size={batch_size})")

    # AutoModel loads the base encoder; its last_hidden_state is pooled below.
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True)
        model = AutoModel.from_pretrained(model_dir)
        model.to(device)
        model.eval()
    except Exception as e:
        print("Failed to load model/tokenizer:", e, file=sys.stderr)
        sys.exit(5)

    all_embs = []
    try:
        # inference_mode disables autograd tracking for a faster, lower-memory pass.
        with torch.inference_mode():
            for i in tqdm(range(0, len(answers), batch_size), desc="Batches"):
                batch = answers[i:i + batch_size]
                enc = tokenizer(batch, padding=True, truncation=True, max_length=256, return_tensors="pt")
                input_ids = enc["input_ids"].to(device)
                attention_mask = enc["attention_mask"].to(device)
                out = model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
                pooled = mean_pooling(out.last_hidden_state, attention_mask)
                # L2-normalize so dot products between rows are cosine similarities.
                pooled = torch.nn.functional.normalize(pooled, p=2, dim=1)
                all_embs.append(pooled.cpu())
    except Exception as e:
        print("Error during encoding:", e, file=sys.stderr)
        sys.exit(6)

    all_embs = torch.cat(all_embs, dim=0)
    print("Final embeddings shape:", all_embs.shape)
    Path(out_path).parent.mkdir(parents=True, exist_ok=True)
    torch.save(all_embs, out_path)
    print("Saved embeddings to:", out_path)
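    # The saved artifact can be reloaded later with torch.load(out_path);
    # row i corresponds to row i of the CSV's 'answer' column.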

    if cfg['upload_back'] and cfg['hf_repo']:
        try:
            from huggingface_hub import HfApi

            # Uses the ambient Hugging Face credentials (e.g. a prior login or
            # HF_TOKEN); the target repo must already exist and be writable.
            api = HfApi()
            print(f"Uploading {out_path} back to repo {cfg['hf_repo']} ...")
            api.upload_file(
                path_or_fileobj=out_path,
                path_in_repo=os.path.basename(out_path),
                repo_id=cfg['hf_repo'],
                repo_type="model",
            )
            print("Upload complete.")
        except Exception as e:
            print("Upload back failed:", e, file=sys.stderr)

    # Sanity check: squared L2 norms of normalized rows should be ~1.0.
    norms = (all_embs * all_embs).sum(dim=1)
    print("Sample norms (should be ~1.0):", norms[:5].tolist())
    return 0


if __name__ == "__main__":
    sys.exit(main())
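
# A minimal downstream sketch (assumptions: query_emb is a hypothetical (1, H)
# tensor produced by the same model + mean_pooling + normalize pipeline):
#
#   embs = torch.load("/app/export_artifacts/answer_embeddings.pt")  # (N, H)
#   scores = embs @ query_emb.squeeze(0)  # cosine similarities (rows are unit-norm)
#   top = scores.topk(5).indices          # row indices into the CSV's 'answer' column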