# regenerate_embeddings.py
"""
Regenerate answer embeddings using the MuRIL model.
This script:
- downloads model (if MODEL_DIR is a repo id),
- reads CSV at CSV_PATH,
- computes mean-pooled, L2-normalized embeddings for 'answer' column,
- saves embeddings to OUT_EMBED_PATH.
Exit codes:
- 0 on success
- non-zero on failure
"""
import os, sys
from pathlib import Path
import torch
import pandas as pd
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModel
from huggingface_hub import snapshot_download
def mean_pooling(last_hidden_state, attention_mask):
    """Mask-aware mean pooling: average token vectors over non-padding positions only."""
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, 1)
sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
return sum_embeddings / sum_mask
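# Quick shape sketch (hypothetical tensors, kept as a comment so the module
# stays import-safe):
#   h = torch.randn(2, 4, 8)                        # (B=2, T=4, H=8)
#   m = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 1]])  # first row: last two tokens are padding
#   mean_pooling(h, m).shape                        # -> torch.Size([2, 8])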
def parse_env():
    # All configuration comes from environment variables (no CLI flags).
    cfg = {}
cfg['model_dir'] = os.getenv("MODEL_DIR", os.getenv("HF_REPO", "Sp2503/Finetuned-multilingualdataset-MuriL-model"))
cfg['csv_path'] = os.getenv("CSV_PATH", "/app/export_artifacts/muril_multilingual_dataset.csv")
cfg['out_path'] = os.getenv("OUT_EMBED_PATH", "/app/export_artifacts/answer_embeddings.pt")
cfg['batch_size'] = int(os.getenv("EMBED_BATCH_SIZE", "64"))
cfg['device'] = os.getenv("DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
cfg['download_cache'] = os.getenv("HF_CACHE_DIR", "/tmp/hf_cache")
cfg['upload_back'] = os.getenv("UPLOAD_BACK", "false").lower() in ("1","true","yes")
cfg['hf_repo'] = os.getenv("HF_REPO", None) # used for upload_back if set
return cfg
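# Because every knob above is a plain environment variable, a one-off run can
# override any of them inline, e.g. (illustrative values):
#   EMBED_BATCH_SIZE=32 DEVICE=cpu python regenerate_embeddings.py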
def main():
cfg = parse_env()
print("Regenerate embeddings with config:", cfg)
model_dir = cfg['model_dir']
# If model_dir looks like a HF repo id (contains '/'), snapshot_download to local cache
if "/" in model_dir and not os.path.isdir(model_dir):
print("Detected HF repo id for model. snapshot_download ->", cfg['download_cache'])
try:
model_dir = snapshot_download(repo_id=cfg['model_dir'], repo_type="model", cache_dir=cfg['download_cache'])
print("Downloaded model to:", model_dir)
except Exception as e:
print("Failed to snapshot_download model:", e, file=sys.stderr)
sys.exit(2)
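        # snapshot_download returns the local snapshot folder (containing
        # config.json, tokenizer files, and weights), typically under a path
        # like <cache>/models--<org>--<name>/snapshots/<revision>/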
csv_path = cfg['csv_path']
out_path = cfg['out_path']
batch_size = cfg['batch_size']
device = cfg['device']
print(f"Loading CSV: {csv_path}")
if not os.path.isfile(csv_path):
print(f"CSV not found at {csv_path}", file=sys.stderr)
sys.exit(3)
df = pd.read_csv(csv_path, dtype=str).fillna("")
if 'answer' not in df.columns:
print("CSV must contain 'answer' column", file=sys.stderr)
sys.exit(4)
    answers = df['answer'].astype(str).tolist()
    if not answers:
        print("CSV contains no rows; nothing to encode", file=sys.stderr)
        sys.exit(7)
    print(f"Encoding {len(answers)} answers on device {device} (batch_size={batch_size})")
# Load tokenizer & model
try:
tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True)
model = AutoModel.from_pretrained(model_dir)
        model.to(device)
        model.eval()  # disable dropout; inference_mode below also skips autograd
except Exception as e:
print("Failed to load model/tokenizer:", e, file=sys.stderr)
sys.exit(5)
# compute embeddings
all_embs = []
try:
with torch.inference_mode():
for i in tqdm(range(0, len(answers), batch_size), desc="Batches"):
batch = answers[i:i+batch_size]
enc = tokenizer(batch, padding=True, truncation=True, max_length=256, return_tensors="pt")
input_ids = enc["input_ids"].to(device)
attention_mask = enc["attention_mask"].to(device)
out = model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
pooled = mean_pooling(out.last_hidden_state, attention_mask) # (B, H)
pooled = torch.nn.functional.normalize(pooled, p=2, dim=1) # L2-normalize
all_embs.append(pooled.cpu())
except Exception as e:
print("Error during encoding:", e, file=sys.stderr)
sys.exit(6)
all_embs = torch.cat(all_embs, dim=0)
print("Final embeddings shape:", all_embs.shape)
Path(out_path).parent.mkdir(parents=True, exist_ok=True)
torch.save(all_embs, out_path)
print("Saved embeddings to:", out_path)
# Optional: upload back to HF repo (requires HF_TOKEN set and HF_REPO)
if cfg['upload_back'] and cfg['hf_repo']:
try:
from huggingface_hub import HfApi
api = HfApi()
print(f"Uploading {out_path} back to repo {cfg['hf_repo']} ...")
api.upload_file(
path_or_fileobj=out_path,
path_in_repo=os.path.basename(out_path),
repo_id=cfg['hf_repo'],
repo_type="model",
)
print("Upload complete.")
        except Exception as e:
            # Non-fatal: the embeddings are already saved locally.
            print("Upload back failed:", e, file=sys.stderr)
    # quick sanity check: rows were L2-normalized, so their norms should be ~1.0
    norms = all_embs.norm(p=2, dim=1)
    print("Sample norms (should be ~1.0):", norms[:5].tolist())
return 0
if __name__ == "__main__":
sys.exit(main())