# MuriL-2.0 / model_loader.py
# Author: Sai809701
# Commit: updated files (6677176)
# model_loader.py
import os
from pathlib import Path
from typing import Optional

import torch
def load_model_and_tokenizer(
    model_repo_dir_or_local_path: str,
    base_model_id: str = "google/muril-base-cased",
    device: Optional[str] = None,
):
    """Load a tokenizer and model from a local directory, handling three layouts.

    Resolution order:
      1. Full model weights present (``model.safetensors`` or
         ``pytorch_model.bin``): load tokenizer + model directly from the dir.
      2. PEFT adapter files present: load a base model (preferring the dir,
         falling back to ``base_model_id``) and attach the adapter.
      3. Fallback: attempt a direct ``from_pretrained`` on the directory.

    Args:
        model_repo_dir_or_local_path: Directory that may contain full weights
            and/or PEFT adapter files.
        base_model_id: Hugging Face model id used as the base when only an
            adapter is present (or local tokenizer/model loading fails).
        device: Target device string; defaults to ``"cuda"`` when available,
            otherwise ``"cpu"``.

    Returns:
        Tuple ``(tokenizer, model_on_device, backend_str)`` where
        ``backend_str`` names which loading path succeeded.

    Raises:
        RuntimeError: If the adapter cannot be applied, or if no loading
            strategy succeeds.
    """
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    model_dir = Path(model_repo_dir_or_local_path)

    # Prefer a full checkpoint (safetensors first) over adapter files.
    full_model_files = ["model.safetensors", "pytorch_model.bin"]
    adapter_files = ["adapter_model.safetensors", "adapter_config.json", "adapter.safetensors"]

    # Import delayed so a missing/broken transformers install surfaces at
    # call time rather than at module import.
    from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification

    # 1) Full model checkpoint in model_dir.
    if any((model_dir / f).exists() for f in full_model_files):
        tokenizer = AutoTokenizer.from_pretrained(str(model_dir), use_fast=True)
        try:
            model = AutoModel.from_pretrained(str(model_dir), trust_remote_code=False)
            backend = "full-AutoModel"
        except Exception:
            # The checkpoint may carry a classification head that plain
            # AutoModel cannot map; retry with the task-specific class.
            model = AutoModelForSequenceClassification.from_pretrained(
                str(model_dir), trust_remote_code=False
            )
            backend = "full-AutoModelForSequenceClassification"
        model.to(device)
        model.eval()
        return tokenizer, model, backend

    # 2) No full model: look for a PEFT adapter to attach to a base model.
    if any((model_dir / af).exists() for af in adapter_files):
        try:
            tokenizer = AutoTokenizer.from_pretrained(str(model_dir), use_fast=True)
        except Exception:
            # Adapter-only dirs often ship without tokenizer files; fall back
            # to the base model's tokenizer.
            tokenizer = AutoTokenizer.from_pretrained(base_model_id, use_fast=True)
        try:
            base = AutoModel.from_pretrained(str(model_dir))
        except Exception:
            base = AutoModel.from_pretrained(base_model_id)
        base.to(device)
        base.eval()
        try:
            from peft import PeftModel

            # PeftModel.from_pretrained picks up adapter weights/config
            # directly from the directory.
            peft_model = PeftModel.from_pretrained(base, str(model_dir), is_trainable=False)
            peft_model.to(device)
            peft_model.eval()
            return tokenizer, peft_model, "peft-attached"
        except Exception as e:
            raise RuntimeError(f"Failed to load/apply PEFT adapter from {model_dir}: {e}") from e

    # 3) Last resort: let transformers resolve whatever the directory holds.
    try:
        tokenizer = AutoTokenizer.from_pretrained(str(model_dir), use_fast=True)
        model = AutoModel.from_pretrained(str(model_dir))
        model.to(device)
        model.eval()
        return tokenizer, model, "auto-fallback"
    except Exception as e:
        raise RuntimeError(f"Unable to load model or adapters from {model_dir}. Error: {e}") from e