import os
import json
from typing import Any, Dict, Optional, List

import joblib
from transformers import PreTrainedModel

from .configuration_knn import KNNConfig


class KNNModel(PreTrainedModel):
    """
    A tiny wrapper so an sklearn KNN (joblib) can be saved/loaded with
    the transformers save_pretrained / from_pretrained pattern.

    Notes:
    - We persist the sklearn object as `model.joblib` inside the folder.
    - Loading from the Hub via `transformers` will require
      `trust_remote_code=True` or using this module locally.
    """

    config_class = KNNConfig
    base_model_prefix = "knn"

    def __init__(self, config: KNNConfig, model: Optional[Any] = None, models: Optional[List] = None):
        super().__init__(config)
        self.knn = model
        self.models = models or []
        self.is_ensemble = config.is_ensemble or len(self.models) > 1
    def forward(self, X, **kwargs):
        """Return predictions for an input array-like X.

        For ensemble models, uses the first model's predictions
        (you can implement voting/averaging logic here if desired).

        This is intentionally simple; you can adapt it to return
        ModelOutput structured objects if desired.
        """
        if self.is_ensemble and self.models:
            return self.models[0].predict(X)
        elif self.knn is not None:
            return self.knn.predict(X)
        else:
            raise ValueError("Model not loaded. Call from_pretrained or load a joblib model first.")
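
    # A minimal sketch (not part of the original behavior) of how majority
    # voting could replace the "first member wins" rule in forward(), assuming
    # the ensemble members are classifiers with a shared label set:
    #
    #     import numpy as np
    #     from scipy import stats
    #     preds = np.stack([m.predict(X) for m in self.models])  # (n_members, n_samples)
    #     return stats.mode(preds, axis=0, keepdims=False).mode  # per-sample majority vote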

    def save_pretrained(self, save_directory: str, **kwargs) -> None:
        """
        Save only the config and the sklearn object(s).

        We intentionally avoid calling the parent `save_pretrained` because the
        transformers implementation expects a PyTorch model (and tries to infer
        a `dtype` from model tensors), which fails for non-torch objects and
        raises the IndexError seen in CI / when running locally. Instead we use
        the config's `save_pretrained` method and persist the sklearn object
        as `model.joblib` (or multiple files for ensembles).
        """
        os.makedirs(save_directory, exist_ok=True)
        self.config.save_pretrained(save_directory)

        if self.is_ensemble and self.models:
            for member_name, model_obj in zip(self.config.ensemble_members, self.models):
                out_path = os.path.join(save_directory, member_name)
                os.makedirs(os.path.dirname(out_path), exist_ok=True)
                joblib.dump(model_obj, out_path)
        elif self.knn is not None:
            out_path = os.path.join(save_directory, "model.joblib")
            joblib.dump(self.knn, out_path)
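
    # After save_pretrained, the target folder contains the serialized config
    # plus the joblib file(s); roughly (illustrative layout):
    #
    #     save_directory/
    #         config.json
    #         model.joblib                   # single model, or instead:
    #         models/<member_name>.joblib    # one file per ensemble member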

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, *model_args, **kwargs):
        """
        Load a KNN model with optional variant selection.

        Supports two modes:
        1. Direct loading: loads model.joblib from the specified path/repo.
        2. Variant selection: specify parameters to auto-select a model variant.

        For ensemble models (7T-21T, Synthetic), automatically loads all sub-models.

        Args:
            pretrained_model_name_or_path: Local path or HF Hub repo ID.
            data_source: Optional. One of: "7T", "21T", "7T-21T", "Synthetic".
            k_neighbors: Optional. 1 or 3.
            metric: Optional. "euclidean" or "manhattan".
            training_version: Optional. For single models only; ignored for ensembles.
            variant: Optional. Direct variant name (e.g., "knn_21T_k1_euclidean").

        Examples:
            # Load default best model
            model = KNNModel.from_pretrained("SaeedLab/dom-formula-assignment-using-knn")

            # Load specific variant by parameters
            model = KNNModel.from_pretrained(
                "SaeedLab/dom-formula-assignment-using-knn",
                data_source="7T-21T",  # This is an ensemble!
                k_neighbors=1,
                metric="euclidean",
            )

            # Load by variant name
            model = KNNModel.from_pretrained(
                "SaeedLab/dom-formula-assignment-using-knn",
                variant="knn_21T_k3_manhattan",
            )
        """
        data_source = kwargs.pop("data_source", None)
        k_neighbors = kwargs.pop("k_neighbors", None)
        metric = kwargs.pop("metric", None)
        training_version = kwargs.pop("training_version", None)
        variant = kwargs.pop("variant", None)

        is_ensemble = data_source in ["7T-21T", "Synthetic"] if data_source else False

        config = cls.config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)

        # Override config fields with any explicitly requested parameters.
        if k_neighbors is not None:
            config.n_neighbors = k_neighbors
        if metric is not None:
            config.metric = metric
        if data_source is not None:
            config.data_source = data_source
        if training_version is not None:
            config.training_version = training_version

        if is_ensemble:
            model_filenames = cls._resolve_ensemble_filenames(
                pretrained_model_name_or_path,
                variant=variant,
                data_source=data_source,
                k_neighbors=k_neighbors,
                metric=metric,
            )

            config.is_ensemble = True
            config.ensemble_members = model_filenames

            models = []
            for model_filename in model_filenames:
                model_file = os.path.join(pretrained_model_name_or_path, model_filename)
                if os.path.exists(model_file):
                    knn = joblib.load(model_file)
                else:
                    # Not available locally: fall back to downloading from the Hub.
                    try:
                        from huggingface_hub import hf_hub_download

                        repo_id = pretrained_model_name_or_path
                        model_path = hf_hub_download(
                            repo_id=repo_id,
                            filename=model_filename,
                            **kwargs.get("hub_kwargs", {})
                        )
                        knn = joblib.load(model_path)
                    except Exception as exc:
                        raise RuntimeError(
                            f"Could not locate or download {model_filename} for {pretrained_model_name_or_path}: {exc}"
                        )
                models.append(knn)

            inst = cls(config=config, models=models)
            return inst
        else:
            model_filename = cls._resolve_model_filename(
                pretrained_model_name_or_path,
                variant=variant,
                data_source=data_source,
                k_neighbors=k_neighbors,
                metric=metric,
                training_version=training_version,
            )

            config.is_ensemble = False

            model_file = os.path.join(pretrained_model_name_or_path, model_filename)
            if os.path.exists(model_file):
                knn = joblib.load(model_file)
            else:
                # Not available locally: fall back to downloading from the Hub.
                try:
                    from huggingface_hub import hf_hub_download

                    repo_id = pretrained_model_name_or_path
                    model_path = hf_hub_download(
                        repo_id=repo_id,
                        filename=model_filename,
                        **kwargs.get("hub_kwargs", {})
                    )
                    knn = joblib.load(model_path)
                except Exception as exc:
                    raise RuntimeError(
                        f"Could not locate or download {model_filename} for {pretrained_model_name_or_path}: {exc}"
                    )

            inst = cls(config=config, model=knn)
            return inst
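
    # The Hub download fallback above forwards extra arguments through a
    # "hub_kwargs" dict to hf_hub_download. An illustrative call (the revision
    # and token values are placeholders, not part of this module):
    #
    #     model = KNNModel.from_pretrained(
    #         "SaeedLab/dom-formula-assignment-using-knn",
    #         data_source="21T", k_neighbors=1, metric="euclidean",
    #         hub_kwargs={"revision": "main", "token": "<hf token>"},
    #     )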

    @staticmethod
    def _resolve_model_filename(
        pretrained_model_name_or_path: str,
        variant: Optional[str] = None,
        data_source: Optional[str] = None,
        k_neighbors: Optional[int] = None,
        metric: Optional[str] = None,
        training_version: Optional[str] = None,
    ) -> str:
        """
        Resolve the model filename based on variant parameters.

        Returns:
            Filename of the .joblib model to load (e.g., "models/knn_21T_k1_euclidean.joblib")
        """
        # An explicit variant name takes precedence over all other parameters.
        if variant:
            if not variant.endswith(".joblib"):
                variant = f"{variant}.joblib"
            if not variant.startswith("models/"):
                return f"models/{variant}"
            return variant

        # No selection parameters given: fall back to the default model.
        if not any([data_source, k_neighbors, metric, training_version]):
            return "models/knn_21T_k1_euclidean.joblib"

        # Look up the requested variant in model_index.json (local or Hub).
        try:
            index_path = os.path.join(pretrained_model_name_or_path, "model_index.json")
            if os.path.exists(index_path):
                with open(index_path, "r") as f:
                    index = json.load(f)
            else:
                from huggingface_hub import hf_hub_download

                index_path = hf_hub_download(
                    repo_id=pretrained_model_name_or_path,
                    filename="model_index.json"
                )
                with open(index_path, "r") as f:
                    index = json.load(f)

            for variant_info in index.get("variants", {}).values():
                matches = True
                if data_source and variant_info.get("data_source") != data_source:
                    matches = False
                if k_neighbors and variant_info.get("k_neighbors") != k_neighbors:
                    matches = False
                if metric and variant_info.get("metric", "").lower() != metric.lower():
                    matches = False
                if training_version and variant_info.get("training_version") != training_version:
                    matches = False

                if matches:
                    return variant_info["filename"]

            raise ValueError(
                f"No model variant found matching: data_source={data_source}, "
                f"k_neighbors={k_neighbors}, metric={metric}, training_version={training_version}"
            )

        except Exception as e:
            # Index unavailable or no match: construct the filename from the parameters.
            if not data_source or not k_neighbors or not metric:
                raise ValueError(
                    "Could not load model_index.json and insufficient parameters provided. "
                    "Please specify: data_source, k_neighbors, and metric"
                ) from e

            ds = data_source.replace("-", "")
            version_suffix = f"_{training_version}" if training_version else ""
            filename = f"models/knn_{ds}_k{k_neighbors}_{metric.lower()}{version_suffix}.joblib"
            return filename
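
    # For reference, the lookup above assumes a model_index.json of roughly
    # this shape (illustrative, inferred from the keys accessed in
    # _resolve_model_filename, not a verbatim copy of the repo's index):
    #
    #     {
    #       "variants": {
    #         "knn_21T_k1_euclidean": {
    #           "data_source": "21T",
    #           "k_neighbors": 1,
    #           "metric": "euclidean",
    #           "training_version": "DOM_training_set_ver3",
    #           "filename": "models/knn_21T_k1_euclidean.joblib"
    #         }
    #       }
    #     }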

    @staticmethod
    def _resolve_ensemble_filenames(
        pretrained_model_name_or_path: str,
        variant: Optional[str] = None,
        data_source: Optional[str] = None,
        k_neighbors: Optional[int] = None,
        metric: Optional[str] = None,
    ) -> List[str]:
        """
        Resolve ensemble model filenames (multiple .joblib files for one logical model).

        For 7T-21T: returns 2 filenames (ver2 and ver3).
        For Synthetic: returns 3 filenames (ver2, ver3, synthetic_data).

        Returns:
            List of filenames to load
        """
        if not data_source:
            raise ValueError("data_source is required for ensemble models")

        if data_source not in ["7T-21T", "Synthetic"]:
            raise ValueError(f"data_source '{data_source}' is not an ensemble model")

        if not k_neighbors or not metric:
            raise ValueError("k_neighbors and metric are required for ensemble models")

        if data_source == "7T-21T":
            training_versions = ["DOM_training_set_ver2", "DOM_training_set_ver3"]
        elif data_source == "Synthetic":
            training_versions = ["DOM_training_set_ver2", "DOM_training_set_ver3", "synthetic_data"]
        else:
            raise ValueError(f"Unknown ensemble type: {data_source}")

        metric_name = metric.capitalize()
        filenames = []
        for version in training_versions:
            filename = f"models/knn_model_Model-{data_source}_K{k_neighbors}_{metric_name}_{version}.joblib"
            filenames.append(filename)

        return filenames
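

# Example usage (a minimal sketch; `X_query` is an illustrative feature array,
# not something defined in this module):
#
#     model = KNNModel.from_pretrained(
#         "SaeedLab/dom-formula-assignment-using-knn",
#         data_source="21T",
#         k_neighbors=1,
#         metric="euclidean",
#     )
#     predictions = model(X_query)   # __call__ dispatches to forward -> knn.predict
#
#     # Re-save the wrapper locally (writes config.json plus the joblib file(s)):
#     model.save_pretrained("./knn-checkpoint")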