import os
import json
from typing import Any, Dict, Optional, List

import joblib
from transformers import PreTrainedModel

from .configuration_knn import KNNConfig


class KNNModel(PreTrainedModel):
    """
    A tiny wrapper so an sklearn KNN (joblib) can be saved/loaded with
    the transformers save_pretrained / from_pretrained pattern.

    Notes:
    - We persist the sklearn object as `model.joblib` inside the folder.
    - Loading from the Hub via `transformers` will require
      `trust_remote_code=True` or using this module locally.
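
    Example (a minimal local round-trip sketch; assumes scikit-learn is
    installed, that `KNNConfig()` is constructible with its defaults, and
    uses an illustrative `./knn_export` path):

        from sklearn.neighbors import KNeighborsClassifier

        knn = KNeighborsClassifier(n_neighbors=1).fit([[0.0], [1.0]], [0, 1])
        model = KNNModel(KNNConfig(), model=knn)
        model.save_pretrained("./knn_export")

        reloaded = KNNModel.from_pretrained("./knn_export")
        preds = reloaded([[0.2]])  # -> array([0])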
    """

    config_class = KNNConfig
    base_model_prefix = "knn"

    def __init__(self, config: KNNConfig, model: Optional[Any] = None, models: Optional[List] = None):
        super().__init__(config)
        # self.knn is the sklearn estimator (e.g., sklearn.neighbors.KNeighborsClassifier)
        # used when this wrapper holds a single (non-ensemble) model
        self.knn = model
        # self.models is the list of sklearn estimators for ensemble models
        self.models = models or []
        # Fall back to False if the config does not define `is_ensemble`
        self.is_ensemble = getattr(config, "is_ensemble", False) or len(self.models) > 1

    def forward(self, X, **kwargs):
        """Return predictions for an input array-like X.

        For ensemble models, the first member's predictions are returned;
        voting or averaging logic can be added here if desired.

        This method is intentionally simple; adapt it to return a structured
        ModelOutput object if needed.
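
        Example (X is any array-like accepted by the underlying sklearn
        estimator; in the single-model case this simply delegates to
        self.knn.predict):

            preds = model([[0.2, 0.4]])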
        """
        if self.is_ensemble and self.models:
            # Use first model for now; could implement ensemble voting
            return self.models[0].predict(X)
        elif self.knn is not None:
            return self.knn.predict(X)
        else:
            raise ValueError("Model not loaded. Call from_pretrained or load a joblib model first.")

    def save_pretrained(self, save_directory: str, **kwargs) -> None:
        """
        Save only the config and the sklearn object(s).

        We intentionally avoid calling the parent `save_pretrained` because the
        transformers implementation expects a PyTorch model (it tries to infer
        a `dtype` from the model's tensors), which fails for non-torch objects
        and raises an IndexError both in CI and when running locally. Instead
        we use the config's `save_pretrained` method and persist the sklearn
        object(s) via joblib: `model.joblib` for single models, or one file per
        member for ensembles.
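
        Resulting layout for the single-model case:

            save_directory/
                config.json
                model.joblib

        Ensemble members are written under their original relative filenames
        (e.g., `models/...`) inside `save_directory`.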
        """
        os.makedirs(save_directory, exist_ok=True)
        # save config.json (PretrainedConfig handles this)
        self.config.save_pretrained(save_directory)

        # persist sklearn object(s) as joblib
        if self.is_ensemble and self.models:
            # Save each ensemble member with its original filename
            for member_name, model_obj in zip(self.config.ensemble_members, self.models):
                out_path = os.path.join(save_directory, member_name)
                os.makedirs(os.path.dirname(out_path), exist_ok=True)
                joblib.dump(model_obj, out_path)
        elif self.knn is not None:
            out_path = os.path.join(save_directory, "model.joblib")
            joblib.dump(self.knn, out_path)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, *model_args, **kwargs):
        """
        Load a KNN model with optional variant selection.

        Supports two modes:
        1. Direct loading: loads model.joblib from the specified path/repo
        2. Variant selection: specify parameters to auto-select a model variant

        For ensemble models (7T-21T, Synthetic), automatically loads all sub-models.

        Args:
            pretrained_model_name_or_path: Local path or HF Hub repo ID
            data_source: Optional. One of: "7T", "21T", "7T-21T", "Synthetic"
            k_neighbors: Optional. 1 or 3
            metric: Optional. "euclidean" or "manhattan"
            training_version: Optional. For single models only, ignored for ensembles
            variant: Optional. Direct variant name (e.g., "knn_21T_k1_euclidean")
            hub_kwargs: Optional dict of keyword arguments forwarded to
                `huggingface_hub.hf_hub_download` (e.g., revision, token)

        Examples:
            # Load default best model
            model = KNNModel.from_pretrained("SaeedLab/dom-formula-assignment-using-knn")

            # Load specific variant by parameters
            model = KNNModel.from_pretrained(
                "SaeedLab/dom-formula-assignment-using-knn",
                data_source="7T-21T",  # This is an ensemble!
                k_neighbors=1,
                metric="euclidean"
            )

            # Load by variant name
            model = KNNModel.from_pretrained(
                "SaeedLab/dom-formula-assignment-using-knn",
                variant="knn_21T_k3_manhattan"
            )
        """
        # Extract variant selection parameters and optional hub download kwargs
        # (popped so they are not forwarded to the config loader)
        data_source = kwargs.pop("data_source", None)
        k_neighbors = kwargs.pop("k_neighbors", None)
        metric = kwargs.pop("metric", None)
        training_version = kwargs.pop("training_version", None)
        variant = kwargs.pop("variant", None)
        hub_kwargs = kwargs.pop("hub_kwargs", {})

        # Determine if this is an ensemble model
        is_ensemble = data_source in ("7T-21T", "Synthetic")

        # load config using parent machinery (handles repo id or local path)
        config = cls.config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)

        # Update config with variant info if provided
        if k_neighbors is not None:
            config.n_neighbors = k_neighbors
        if metric is not None:
            config.metric = metric
        if data_source is not None:
            config.data_source = data_source
        if training_version is not None:
            config.training_version = training_version

        if is_ensemble:
            # Load ensemble model (multiple joblib files)
            model_filenames = cls._resolve_ensemble_filenames(
                pretrained_model_name_or_path,
                variant=variant,
                data_source=data_source,
                k_neighbors=k_neighbors,
                metric=metric,
            )
            
            config.is_ensemble = True
            config.ensemble_members = model_filenames
            
            models = []
            for model_filename in model_filenames:
                model_file = os.path.join(pretrained_model_name_or_path, model_filename)
                if os.path.exists(model_file):
                    knn = joblib.load(model_file)
                else:
                    # try to download from hub
                    try:
                        from huggingface_hub import hf_hub_download
                        repo_id = pretrained_model_name_or_path
                        model_path = hf_hub_download(
                            repo_id=repo_id,
                            filename=model_filename,
                            **hub_kwargs
                        )
                        knn = joblib.load(model_path)
                    except Exception as exc:
                        raise RuntimeError(
                            f"Could not locate or download {model_filename} for {pretrained_model_name_or_path}: {exc}"
                        )
                models.append(knn)
            
            inst = cls(config=config, models=models)
            return inst
        else:
            # Load single model
            model_filename = cls._resolve_model_filename(
                pretrained_model_name_or_path,
                variant=variant,
                data_source=data_source,
                k_neighbors=k_neighbors,
                metric=metric,
                training_version=training_version,
            )

            config.is_ensemble = False
            
            # Attempt to resolve model file
            model_file = os.path.join(pretrained_model_name_or_path, model_filename)
            if os.path.exists(model_file):
                knn = joblib.load(model_file)
            else:
                # try to download from hub
                try:
                    from huggingface_hub import hf_hub_download

                    repo_id = pretrained_model_name_or_path
                    model_path = hf_hub_download(
                        repo_id=repo_id,
                        filename=model_filename,
                        **hub_kwargs
                    )
                    knn = joblib.load(model_path)
                except Exception as exc:
                    raise RuntimeError(
                        f"Could not locate or download {model_filename} for {pretrained_model_name_or_path}: {exc}"
                    )

            inst = cls(config=config, model=knn)
            return inst

    @staticmethod
    def _resolve_model_filename(
        pretrained_model_name_or_path: str,
        variant: Optional[str] = None,
        data_source: Optional[str] = None,
        k_neighbors: Optional[int] = None,
        metric: Optional[str] = None,
        training_version: Optional[str] = None,
    ) -> str:
        """
        Resolve the model filename based on variant parameters.

        Returns:
            Filename of the .joblib model to load (e.g., "models/knn_21T_k1_euclidean.joblib")
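
        Examples (per the resolution rules implemented below):
            variant="knn_21T_k3_manhattan" -> "models/knn_21T_k3_manhattan.joblib"
            no parameters -> "model.joblib" if present locally (as written by
                `save_pretrained`), otherwise "models/knn_21T_k1_euclidean.joblib"
            data_source="7T", k_neighbors=3, metric="manhattan", index unavailable
                -> "models/knn_7T_k3_manhattan.joblib"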
        """
        # If direct variant name provided, use it
        if variant:
            # Ensure .joblib extension
            if not variant.endswith(".joblib"):
                variant = f"{variant}.joblib"
            # Check if it needs models/ prefix
            if not variant.startswith("models/"):
                return f"models/{variant}"
            return variant

        # If no parameters are provided, prefer a locally saved `model.joblib`
        # (the file written by `save_pretrained`); otherwise fall back to the
        # default (best performing) Hub variant.
        if not any([data_source, k_neighbors, metric, training_version]):
            if os.path.exists(os.path.join(pretrained_model_name_or_path, "model.joblib")):
                return "model.joblib"
            return "models/knn_21T_k1_euclidean.joblib"

        # Try to load the model index to find a matching variant. Only the
        # index lookup itself is guarded so that a genuine "no match" error
        # is not swallowed by the filename-construction fallback.
        try:
            index_path = os.path.join(pretrained_model_name_or_path, "model_index.json")
            if not os.path.exists(index_path):
                # Try to download the index from the Hub
                from huggingface_hub import hf_hub_download
                index_path = hf_hub_download(
                    repo_id=pretrained_model_name_or_path,
                    filename="model_index.json"
                )
            with open(index_path, "r") as f:
                index = json.load(f)
        except Exception as e:
            # Fallback: construct the filename directly from the parameters
            if not data_source or not k_neighbors or not metric:
                raise ValueError(
                    "Could not load model_index.json and insufficient parameters provided. "
                    "Please specify: data_source, k_neighbors, and metric"
                ) from e

            ds = data_source.replace("-", "")  # "7T-21T" -> "7T21T"
            version_suffix = f"_{training_version}" if training_version else ""
            return f"models/knn_{ds}_k{k_neighbors}_{metric.lower()}{version_suffix}.joblib"

        # Search the index for a matching variant
        for variant_info in index.get("variants", {}).values():
            matches = True
            if data_source and variant_info.get("data_source") != data_source:
                matches = False
            if k_neighbors and variant_info.get("k_neighbors") != k_neighbors:
                matches = False
            if metric and (variant_info.get("metric") or "").lower() != metric.lower():
                matches = False
            if training_version and variant_info.get("training_version") != training_version:
                matches = False

            if matches:
                return variant_info["filename"]

        # No model variant matched the requested parameters
        raise ValueError(
            f"No model variant found matching: data_source={data_source}, "
            f"k_neighbors={k_neighbors}, metric={metric}, training_version={training_version}"
        )

    @staticmethod
    def _resolve_ensemble_filenames(
        pretrained_model_name_or_path: str,
        variant: Optional[str] = None,
        data_source: Optional[str] = None,
        k_neighbors: Optional[int] = None,
        metric: Optional[str] = None,
    ) -> List[str]:
        """
        Resolve ensemble model filenames (multiple .joblib files for one logical model).

        For 7T-21T: returns 2 filenames (ver2 and ver3)
        For Synthetic: returns 3 filenames (ver2, ver3, synthetic_data)

        Returns:
            List of filenames to load
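
        Example (data_source="7T-21T", k_neighbors=1, metric="euclidean"):
            ["models/knn_model_Model-7T-21T_K1_Euclidean_DOM_training_set_ver2.joblib",
             "models/knn_model_Model-7T-21T_K1_Euclidean_DOM_training_set_ver3.joblib"]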
        """
        if not data_source:
            raise ValueError("data_source is required for ensemble models")
        
        if data_source not in ["7T-21T", "Synthetic"]:
            raise ValueError(f"data_source '{data_source}' is not an ensemble model")

        if not k_neighbors or not metric:
            raise ValueError("k_neighbors and metric are required for ensemble models")

        # Define ensemble members for each type
        if data_source == "7T-21T":
            training_versions = ["DOM_training_set_ver2", "DOM_training_set_ver3"]
        elif data_source == "Synthetic":
            training_versions = ["DOM_training_set_ver2", "DOM_training_set_ver3", "synthetic_data"]
        else:
            raise ValueError(f"Unknown ensemble type: {data_source}")

        # Construct filenames based on original naming pattern
        # Pattern: knn_model_Model-{data_source}_K{k}_{Metric}_{training_version}.joblib
        metric_name = metric.capitalize()  # "euclidean" -> "Euclidean"
        filenames = []
        for version in training_versions:
            filename = f"models/knn_model_Model-{data_source}_K{k_neighbors}_{metric_name}_{version}.joblib"
            filenames.append(filename)

        return filenames