File size: 1,105 Bytes
520c582 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
from transformers import PretrainedConfig
from typing import List, Optional
class KNNConfig(PretrainedConfig):
    """
    Transformers-style configuration object describing a scikit-learn KNN model.

    Only descriptive metadata is kept here, enough to document the model on
    the Hub; this class assigns no fitted estimator state. Ensemble models
    (7T-21T, Synthetic) set ``is_ensemble=True`` and list their sub-model
    filenames in ``ensemble_members``.

    Args:
        n_neighbors: Number of neighbors recorded for the model (default 3).
        metric: Name of the distance metric (default ``"euclidean"``).
        feature_names: Input feature names; ``None`` or an empty list is
            stored as a new empty list.
        is_ensemble: Whether this config describes an ensemble model.
        ensemble_members: Sub-model filenames; ``None`` or an empty list is
            stored as a new empty list.
        data_source: Optional string stored as-is. Its meaning is not defined
            in this file (presumably a training-data identifier -- confirm).
        training_version: Optional string stored as-is (presumably a version
            tag for the training run -- confirm).
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "knn"

    def __init__(
        self,
        n_neighbors: int = 3,
        metric: str = "euclidean",
        feature_names: Optional[List[str]] = None,
        is_ensemble: bool = False,
        ensemble_members: Optional[List[str]] = None,
        data_source: Optional[str] = None,
        training_version: Optional[str] = None,
        **kwargs,
    ):
        # Scalar hyper-parameters describing the estimator.
        self.n_neighbors = n_neighbors
        self.metric = metric

        # Falsy inputs (None or an empty list) are replaced by a fresh list
        # per instance, so no mutable default is ever shared between configs.
        self.feature_names = feature_names if feature_names else []

        # Ensemble description.
        self.is_ensemble = is_ensemble
        self.ensemble_members = ensemble_members if ensemble_members else []

        # Provenance metadata, stored verbatim.
        self.data_source = data_source
        self.training_version = training_version

        # Forward any remaining keyword arguments to the base config class.
        super().__init__(**kwargs)
|