File size: 3,158 Bytes
f481275
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# predictor.py
import os
import torch
import numpy as np
import joblib
from model import create_model_30day


# Region name -> filename of the fitted scaler saved for that region.
# The names are plain relative filenames; MineROIPredictor passes them
# unchanged to os.path.exists / joblib.load, so they are resolved against
# the process's current working directory.
SCALER_PATHS = {
    region: f"scaler_{region}.joblib"
    for region in ("texas", "china", "ethiopia")
}


class MineROIPredictor:
    """Load a trained ROI classifier plus per-region scalers and run inference.

    The predictor scales a raw feature sequence with the scaler belonging to
    the requested region, feeds it to the model and reports the predicted
    class together with the softmax probabilities of the three ROI classes.
    """

    def __init__(self, model_path, device=None):
        """Load every scaler in ``SCALER_PATHS`` and the model weights.

        Args:
            model_path: Path to a saved ``state_dict`` readable by
                ``torch.load``.
            device: Device to run on. When omitted, CUDA is used if
                available, otherwise the CPU.

        Raises:
            FileNotFoundError: If any scaler file is missing.
            KeyError: If the checkpoint has no ``'spectral.complex_weight'``
                entry (needed to infer the model's input dimension).
        """
        # NOTE: stored for reference only; nothing in this class enforces
        # that sequences passed to predict() have this length.
        self.window_size = 30
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Index in this list == class index produced by the model's argmax.
        self.class_names = [
            'Unprofitable (ROI ≤ 0)',
            'Marginal (0 < ROI < 1)',
            'Profitable (ROI ≥ 1)'
        ]

        # Load all scalers up front so a missing file fails at construction
        # time rather than on the first prediction for that region.
        # NOTE(security): joblib.load unpickles the file; only load scaler
        # files from a trusted source.
        self.scalers = {}
        for region, path in SCALER_PATHS.items():
            if not os.path.exists(path):
                raise FileNotFoundError(f"Scaler not found for {region}: {path}")
            self.scalers[region] = joblib.load(path)

        # Load model weights.
        # NOTE(security): torch.load unpickles the checkpoint; only load
        # checkpoints from a trusted source (consider weights_only=True if
        # the installed torch version supports it).
        state_dict = torch.load(model_path, map_location=self.device)

        # Infer input_dim from the first dimension of the spectral layer's
        # weight, so it does not have to be passed in separately.
        try:
            spectral_weight = state_dict['spectral.complex_weight']
        except KeyError as err:
            raise KeyError(
                "Checkpoint has no 'spectral.complex_weight' entry; "
                "expected a plain model state_dict"
            ) from err
        self.input_dim = spectral_weight.shape[0]

        # Rebuild the architecture, then restore the trained weights and
        # switch to inference mode.
        self.model = create_model_30day(self.input_dim)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()

    def normalize_sequence(self, sequence: np.ndarray, region: str) -> np.ndarray:
        """Scale ``sequence`` with the already-fitted scaler of ``region``.

        Args:
            sequence: Array-like whose last axis is the feature axis,
                typically of shape (L, C). Leading axes are flattened for
                the scaler and restored afterwards.
            region: 'texas', 'china', or 'ethiopia'.

        Returns:
            The scaled values, reshaped to the input's original shape.

        Raises:
            ValueError: If ``region`` has no loaded scaler.
        """
        if region not in self.scalers:
            raise ValueError(f"Unknown region '{region}'. Expected one of {list(self.scalers.keys())}")

        scaler = self.scalers[region]

        # Accept any array-like (e.g. nested lists), not just ndarrays.
        sequence = np.asarray(sequence)
        original_shape = sequence.shape
        rows = sequence.reshape(-1, original_shape[-1])

        # Only transform here; the scaler was fitted during preprocessing
        # and must never be re-fitted at inference time.
        scaled_rows = scaler.transform(rows)

        return scaled_rows.reshape(original_shape)

    def predict(self, sequence: np.ndarray, region: str):
        """Classify one raw feature sequence.

        Args:
            sequence: Array-like of shape (L, C) holding *raw* (unscaled)
                features.
            region: Which region's scaler to apply.

        Returns:
            A dict with ``predicted_class`` (int index), ``predicted_label``
            (entry of ``class_names``), ``probabilities`` (softmax values
            keyed 'unprofitable' / 'marginal' / 'profitable') and
            ``confidence`` (probability of the predicted class).

        Raises:
            ValueError: If ``sequence`` is not a non-empty 2-D array or
                ``region`` is unknown.
        """
        sequence = np.asarray(sequence)
        # Fail early with a clear message instead of an opaque error from
        # the scaler or the model further down.
        if sequence.ndim != 2:
            raise ValueError(
                f"Expected a 2-D sequence of shape (L, C), got shape {sequence.shape}"
            )
        if sequence.size == 0:
            raise ValueError("Expected a non-empty sequence of shape (L, C)")

        # 1) Scale using the requested region's scaler.
        scaled = self.normalize_sequence(sequence, region)

        # 2) Convert to a float tensor and add a batch axis: (1, L, C).
        #    NOTE(review): the model is fed (batch, length, channels) as-is;
        #    confirm this matches the layout the model was trained with.
        seq_tensor = torch.from_numpy(scaled).float().unsqueeze(0).to(self.device)

        with torch.no_grad():
            logits = self.model(seq_tensor)
            # Softmax over dim 1, i.e. across classes for the single batch
            # entry (assumes logits of shape (1, num_classes)).
            probabilities = torch.softmax(logits, dim=1)
            predicted_class = torch.argmax(probabilities, dim=1).item()

        probs = probabilities.cpu().numpy()[0]

        return {
            "predicted_class": predicted_class,
            "predicted_label": self.class_names[predicted_class],
            "probabilities": {
                "unprofitable": float(probs[0]),
                "marginal": float(probs[1]),
                "profitable": float(probs[2]),
            },
            "confidence": float(probs[predicted_class]),
        }