MineROI-Net / predictor.py
sithuWiki's picture
upload 7 .py files
f481275 verified
raw
history blame
3.16 kB
# predictor.py
import os
import torch
import numpy as np
import joblib
from model import create_model_30day
# Region name -> file name of the fitted scaler saved for that region.
# The paths are relative, so they are resolved against the current working
# directory at the time the predictor is constructed.
SCALER_PATHS = {
    "texas": "scaler_texas.joblib",
    "china": "scaler_china.joblib",
    "ethiopia": "scaler_ethiopia.joblib",
}
class MineROIPredictor:
    """Run single-sequence ROI classification with a trained model.

    On construction the predictor loads one fitted scaler per region listed
    in ``SCALER_PATHS`` and the model weights from ``model_path``.
    ``predict`` then scales a raw feature sequence with the scaler of the
    requested region and returns the predicted class with its probabilities.
    """

    def __init__(self, model_path, device=None):
        """Load the per-region scalers and the model weights.

        Args:
            model_path: Path to a saved ``state_dict`` readable by
                ``torch.load``.
            device: Device to run on. When falsy, CUDA is used if available,
                otherwise the CPU.

        Raises:
            FileNotFoundError: If any scaler file in ``SCALER_PATHS`` is
                missing.
            KeyError: If the state dict has no ``'spectral.complex_weight'``
                entry.
        """
        # NOTE(review): window_size is stored but not read anywhere in this
        # class -- confirm whether the model requires exactly 30 time steps.
        self.window_size = 30
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Index in this list == class index produced by the model's argmax.
        self.class_names = [
            'Unprofitable (ROI ≤ 0)',
            'Marginal (0 < ROI < 1)',
            'Profitable (ROI ≥ 1)'
        ]

        # Load every regional scaler up front so a missing file fails fast.
        self.scalers = {}
        for region, path in SCALER_PATHS.items():
            if not os.path.exists(path):
                raise FileNotFoundError(f"Scaler not found for {region}: {path}")
            # NOTE: joblib files are pickle-based; only load scaler files
            # that come from a trusted source.
            self.scalers[region] = joblib.load(path)

        # Load model weights.
        # NOTE(review): torch.load unpickles the file; if the installed torch
        # version supports it, consider passing weights_only=True for
        # checkpoints that are not fully trusted.
        state_dict = torch.load(model_path, map_location=self.device)

        # Infer input_dim from the first dimension of the spectral layer
        # weight (presumably the value the model was built with in training
        # -- this relies on the checkpoint being a plain state_dict).
        self.input_dim = state_dict['spectral.complex_weight'].shape[0]

        # Rebuild the architecture, then restore the weights and switch to
        # evaluation mode for inference.
        self.model = create_model_30day(self.input_dim)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()

    def normalize_sequence(self, sequence: np.ndarray, region: str) -> np.ndarray:
        """Scale a raw feature sequence with the scaler of ``region``.

        Args:
            sequence: Array-like whose last axis is the feature axis,
                typically of shape (L, C). Lists are accepted and converted
                to an array.
            region: Key of the scaler to use; must be one of the loaded
                regions ('texas', 'china' or 'ethiopia' by default).

        Returns:
            The scaled values, with the same shape as the input.

        Raises:
            ValueError: If ``region`` is unknown or ``sequence`` has no axes.
        """
        if region not in self.scalers:
            raise ValueError(f"Unknown region '{region}'. Expected one of {list(self.scalers.keys())}")
        scaler = self.scalers[region]

        # Accept any array-like (e.g. nested lists), not only ndarrays.
        values = np.asarray(sequence)
        if values.ndim == 0:
            raise ValueError("sequence must have at least one dimension, got a scalar")

        original_shape = values.shape
        # The scaler works on 2-D (rows, features) input: flatten every
        # leading axis into rows and keep the feature axis.
        flat = values.reshape(-1, original_shape[-1])
        # Only transform, never fit here.
        scaled = scaler.transform(flat)
        return scaled.reshape(original_shape)

    def predict(self, sequence: np.ndarray, region: str):
        """Classify one raw feature sequence.

        Args:
            sequence: Array-like of shape (L, C) holding *raw* (unscaled)
                features.
            region: Which region's scaler to apply before inference.

        Returns:
            A dict with ``predicted_class`` (int index), ``predicted_label``
            (entry of ``class_names``), ``probabilities`` (softmax output
            keyed 'unprofitable' / 'marginal' / 'profitable') and
            ``confidence`` (probability of the predicted class).

        Raises:
            ValueError: If ``sequence`` is not 2-D or ``region`` is unknown.
        """
        sequence = np.asarray(sequence)
        # A batch axis is added below, so anything but a single (L, C)
        # sequence would reach the model with the wrong rank.
        if sequence.ndim != 2:
            raise ValueError(
                f"sequence must be 2-D with shape (L, C), got shape {sequence.shape}"
            )

        # 1) Scale using the scaler of the requested region.
        scaled = self.normalize_sequence(sequence, region)

        # 2) To torch, adding a leading batch axis: (L, C) -> (1, L, C).
        seq_tensor = torch.from_numpy(scaled).float().unsqueeze(0).to(self.device)

        # 3) Forward pass without gradient tracking.
        with torch.no_grad():
            logits = self.model(seq_tensor)
            probabilities = torch.softmax(logits, dim=1)
            predicted_class = torch.argmax(probabilities, dim=1).item()

        probs = probabilities.cpu().numpy()[0]
        return {
            "predicted_class": predicted_class,
            "predicted_label": self.class_names[predicted_class],
            "probabilities": {
                "unprofitable": float(probs[0]),
                "marginal": float(probs[1]),
                "profitable": float(probs[2]),
            },
            "confidence": float(probs[predicted_class]),
        }