import os
import json
from typing import Any, Dict, Optional, List
import joblib
from transformers import PreTrainedModel
from .configuration_knn import KNNConfig
class KNNModel(PreTrainedModel):
"""
A tiny wrapper so an sklearn KNN (joblib) can be saved/loaded with
the transformers save_pretrained / from_pretrained pattern.
Notes:
- We persist the sklearn object as `model.joblib` inside the folder.
- Loading from the Hub via `transformers` will require
`trust_remote_code=True` or using this module locally.
"""
config_class = KNNConfig
base_model_prefix = "knn"
def __init__(self, config: KNNConfig, model: Optional[Any] = None, models: Optional[List] = None):
super().__init__(config)
# self.knn is the actual sklearn KNN object (e.g., sklearn.neighbors.KNeighborsClassifier)
# for single models
self.knn = model
# self.models is a list of sklearn KNN objects for ensemble models
self.models = models or []
self.is_ensemble = config.is_ensemble or len(self.models) > 1
def forward(self, X, **kwargs):
"""Return predictions for an input array-like X.
        For ensemble models, uses the first model's predictions; if you want
        combined predictions, see the illustrative `ensemble_majority_vote`
        sketch defined after this method.
This is intentionally simple; you can adapt to return ModelOutput
structured objects if desired.
"""
if self.is_ensemble and self.models:
# Use first model for now; could implement ensemble voting
return self.models[0].predict(X)
elif self.knn is not None:
return self.knn.predict(X)
else:
raise ValueError("Model not loaded. Call from_pretrained or load a joblib model first.")
def save_pretrained(self, save_directory: str, **kwargs) -> None:
"""
Save only the config and the sklearn object(s).
We intentionally avoid calling the parent `save_pretrained` because the
transformers implementation expects a PyTorch model (and tries to infer
a `dtype` from model tensors), which fails for non-torch objects and
        raises an IndexError both in CI and when running locally. Instead we use
the config's `save_pretrained` method and persist the sklearn object
as `model.joblib` (or multiple files for ensembles).
"""
os.makedirs(save_directory, exist_ok=True)
# save config.json (PretrainedConfig handles this)
self.config.save_pretrained(save_directory)
# persist sklearn object(s) as joblib
if self.is_ensemble and self.models:
# Save each ensemble member with its original filename
            for member_name, model_obj in zip(self.config.ensemble_members, self.models):
                out_path = os.path.join(save_directory, member_name)
                # Guard against member names that carry no directory component.
                member_dir = os.path.dirname(out_path)
                if member_dir:
                    os.makedirs(member_dir, exist_ok=True)
joblib.dump(model_obj, out_path)
elif self.knn is not None:
out_path = os.path.join(save_directory, "model.joblib")
joblib.dump(self.knn, out_path)
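    # Expected on-disk layout after `save_pretrained("./ckpt")`:
    #   ./ckpt/config.json     (written by the config)
    #   ./ckpt/model.joblib    (single model)
    # For ensembles, one .joblib per entry in `config.ensemble_members` instead
    # (e.g. ./ckpt/models/<member filename>.joblib).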
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str, *model_args, **kwargs):
"""
Load a KNN model with optional variant selection.
Supports two modes:
        1. Direct loading: loads `model.joblib` from the specified path if present,
           otherwise falls back to the default best-performing variant
2. Variant selection: specify parameters to auto-select a model variant
For ensemble models (7T-21T, Synthetic), automatically loads all sub-models.
Args:
pretrained_model_name_or_path: Local path or HF Hub repo ID
data_source: Optional. One of: "7T", "21T", "7T-21T", "Synthetic"
k_neighbors: Optional. 1 or 3
metric: Optional. "euclidean" or "manhattan"
training_version: Optional. For single models only, ignored for ensembles
variant: Optional. Direct variant name (e.g., "knn_21T_k1_euclidean")
Examples:
# Load default best model
model = KNNModel.from_pretrained("SaeedLab/dom-formula-assignment-using-knn")
# Load specific variant by parameters
model = KNNModel.from_pretrained(
"SaeedLab/dom-formula-assignment-using-knn",
data_source="7T-21T", # This is an ensemble!
k_neighbors=1,
metric="euclidean"
)
# Load by variant name
model = KNNModel.from_pretrained(
"SaeedLab/dom-formula-assignment-using-knn",
variant="knn_21T_k3_manhattan"
)
"""
# Extract variant selection parameters
data_source = kwargs.pop("data_source", None)
k_neighbors = kwargs.pop("k_neighbors", None)
metric = kwargs.pop("metric", None)
training_version = kwargs.pop("training_version", None)
variant = kwargs.pop("variant", None)
# Determine if this is an ensemble model
is_ensemble = data_source in ["7T-21T", "Synthetic"] if data_source else False
# load config using parent machinery (handles repo id or local path)
config = cls.config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
# Update config with variant info if provided
if k_neighbors is not None:
config.n_neighbors = k_neighbors
if metric is not None:
config.metric = metric
if data_source is not None:
config.data_source = data_source
if training_version is not None:
config.training_version = training_version
if is_ensemble:
# Load ensemble model (multiple joblib files)
model_filenames = cls._resolve_ensemble_filenames(
pretrained_model_name_or_path,
variant=variant,
data_source=data_source,
k_neighbors=k_neighbors,
metric=metric,
)
config.is_ensemble = True
config.ensemble_members = model_filenames
models = []
for model_filename in model_filenames:
model_file = os.path.join(pretrained_model_name_or_path, model_filename)
if os.path.exists(model_file):
knn = joblib.load(model_file)
else:
# try to download from hub
try:
from huggingface_hub import hf_hub_download
repo_id = pretrained_model_name_or_path
model_path = hf_hub_download(
repo_id=repo_id,
filename=model_filename,
**kwargs.get("hub_kwargs", {})
)
knn = joblib.load(model_path)
except Exception as exc:
raise RuntimeError(
f"Could not locate or download {model_filename} for {pretrained_model_name_or_path}: {exc}"
)
models.append(knn)
inst = cls(config=config, models=models)
return inst
else:
# Load single model
model_filename = cls._resolve_model_filename(
pretrained_model_name_or_path,
variant=variant,
data_source=data_source,
k_neighbors=k_neighbors,
metric=metric,
training_version=training_version,
)
config.is_ensemble = False
# Attempt to resolve model file
model_file = os.path.join(pretrained_model_name_or_path, model_filename)
if os.path.exists(model_file):
knn = joblib.load(model_file)
else:
# try to download from hub
try:
from huggingface_hub import hf_hub_download
repo_id = pretrained_model_name_or_path
model_path = hf_hub_download(
repo_id=repo_id,
filename=model_filename,
**kwargs.get("hub_kwargs", {})
)
knn = joblib.load(model_path)
except Exception as exc:
raise RuntimeError(
f"Could not locate or download {model_filename} for {pretrained_model_name_or_path}: {exc}"
)
inst = cls(config=config, model=knn)
return inst
@staticmethod
def _resolve_model_filename(
pretrained_model_name_or_path: str,
variant: Optional[str] = None,
data_source: Optional[str] = None,
k_neighbors: Optional[int] = None,
metric: Optional[str] = None,
training_version: Optional[str] = None,
) -> str:
"""
Resolve the model filename based on variant parameters.
Returns:
Filename of the .joblib model to load (e.g., "models/knn_21T_k1_euclidean.joblib")
"""
# If direct variant name provided, use it
if variant:
# Ensure .joblib extension
if not variant.endswith(".joblib"):
variant = f"{variant}.joblib"
# Check if it needs models/ prefix
if not variant.startswith("models/"):
return f"models/{variant}"
return variant
        # If no parameters provided, prefer a locally saved `model.joblib`
        # (what `save_pretrained` writes); otherwise fall back to the default
        # best-performing model.
        if not any([data_source, k_neighbors, metric, training_version]):
            local_default = os.path.join(pretrained_model_name_or_path, "model.joblib")
            if os.path.exists(local_default):
                return "model.joblib"
            return "models/knn_21T_k1_euclidean.joblib"
# Try to load model index to find matching variant
try:
index_path = os.path.join(pretrained_model_name_or_path, "model_index.json")
if os.path.exists(index_path):
with open(index_path, "r") as f:
index = json.load(f)
else:
# Try to download from hub
from huggingface_hub import hf_hub_download
index_path = hf_hub_download(
repo_id=pretrained_model_name_or_path,
filename="model_index.json"
)
with open(index_path, "r") as f:
index = json.load(f)
# Search for matching variant
for variant_name, variant_info in index.get("variants", {}).items():
matches = True
if data_source and variant_info.get("data_source") != data_source:
matches = False
if k_neighbors and variant_info.get("k_neighbors") != k_neighbors:
matches = False
                if metric and (variant_info.get("metric") or "").lower() != metric.lower():
matches = False
if training_version and variant_info.get("training_version") != training_version:
matches = False
if matches:
return variant_info["filename"]
# No match found
raise ValueError(
f"No model variant found matching: data_source={data_source}, "
f"k_neighbors={k_neighbors}, metric={metric}, training_version={training_version}"
)
except Exception as e:
# Fallback: construct filename from parameters
if not data_source or not k_neighbors or not metric:
raise ValueError(
"Could not load model_index.json and insufficient parameters provided. "
"Please specify: data_source, k_neighbors, and metric"
) from e
# Construct filename
ds = data_source.replace("-", "") # "7T-21T" -> "7T21T"
version_suffix = f"_{training_version}" if training_version else ""
filename = f"models/knn_{ds}_k{k_neighbors}_{metric.lower()}{version_suffix}.joblib"
return filename
@staticmethod
def _resolve_ensemble_filenames(
pretrained_model_name_or_path: str,
variant: Optional[str] = None,
data_source: Optional[str] = None,
k_neighbors: Optional[int] = None,
metric: Optional[str] = None,
) -> List[str]:
"""
Resolve ensemble model filenames (multiple .joblib files for one logical model).
For 7T-21T: returns 2 filenames (ver2 and ver3)
For Synthetic: returns 3 filenames (ver2, ver3, synthetic_data)
Returns:
List of filenames to load
"""
if not data_source:
raise ValueError("data_source is required for ensemble models")
if data_source not in ["7T-21T", "Synthetic"]:
raise ValueError(f"data_source '{data_source}' is not an ensemble model")
if not k_neighbors or not metric:
raise ValueError("k_neighbors and metric are required for ensemble models")
# Define ensemble members for each type
if data_source == "7T-21T":
training_versions = ["DOM_training_set_ver2", "DOM_training_set_ver3"]
elif data_source == "Synthetic":
training_versions = ["DOM_training_set_ver2", "DOM_training_set_ver3", "synthetic_data"]
else:
raise ValueError(f"Unknown ensemble type: {data_source}")
# Construct filenames based on original naming pattern
# Pattern: knn_model_Model-{data_source}_K{k}_{Metric}_{training_version}.joblib
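        # e.g., for data_source="7T-21T", k_neighbors=1, metric="euclidean" this
        # resolves (assuming the repo follows this scheme) to:
        #   models/knn_model_Model-7T-21T_K1_Euclidean_DOM_training_set_ver2.joblib
        #   models/knn_model_Model-7T-21T_K1_Euclidean_DOM_training_set_ver3.joblib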
metric_name = metric.capitalize() # "euclidean" -> "Euclidean"
filenames = []
for version in training_versions:
filename = f"models/knn_model_Model-{data_source}_K{k_neighbors}_{metric_name}_{version}.joblib"
filenames.append(filename)
return filenames
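# Example save/load round trip (a minimal sketch, kept as a comment so nothing
# runs at import time). It assumes this module is imported as part of its
# package (so the relative import of `KNNConfig` resolves), that `KNNConfig`
# accepts keyword arguments such as `n_neighbors` and `metric`, and that
# X_train / y_train / X_test are placeholder feature/label arrays:
#
#   from sklearn.neighbors import KNeighborsClassifier
#
#   knn = KNeighborsClassifier(n_neighbors=1, metric="euclidean").fit(X_train, y_train)
#   model = KNNModel(KNNConfig(n_neighbors=1, metric="euclidean"), model=knn)
#   model.save_pretrained("./knn-checkpoint")   # writes config.json + model.joblib
#
#   reloaded = KNNModel.from_pretrained("./knn-checkpoint")
#   preds = reloaded(X_test)                    # dispatches to knn.predict(X_test)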