|
|
""" |
|
|
Predictor - Orchestrates the complete prediction pipeline |
|
|
Combines validation, model inference, and treatment generation |
|
|
""" |
|
|
|
|
|
from .validator import FishImageValidator |
|
|
from .model_loader import VGG16ModelLoader |
|
|
from .treatment import TreatmentGenerator |
|
|
from .config import CLASSES, CONFIDENCE_THRESHOLD |
|
|
|
|
|
class FishDiseasePredictor:
    """Main prediction pipeline that coordinates all components.

    Stages (each one stops the pipeline when it fails):

    1. File validation via ``FishImageValidator.validate_file``
       (:meth:`predict` only; :meth:`predict_from_image` already has an image).
    2. Content validation via ``FishImageValidator.validate_with_gemini``.
    3. Classification with ``VGG16ModelLoader.predict``, followed by treatment
       recommendations from ``TreatmentGenerator`` when the model's
       confidence is at least ``confidence_threshold``.

    Both entry points return a result dict of the same shape (documented on
    :meth:`predict`); validation failures are reported in that dict rather
    than raised.
    """

    def __init__(self, config, gemini_model=None):
        """
        Initialize predictor with all components

        Args:
            config: Configuration dictionary with all settings.
                Required keys:
                    'MODEL_PATH': path handed to VGG16ModelLoader.
                    'CLASSES': ordered class names; position i names model
                        output i, and len(CLASSES) sets the model's class count.
                Optional keys (default in parentheses):
                    'CONFIDENCE_THRESHOLD' (the package's config.CONFIDENCE_THRESHOLD),
                    'MAX_FILE_SIZE_MB' (10), 'MIN_IMAGE_SIZE_PX' (100),
                    'VALID_EXTENSIONS' (None), 'DEVICE' ('cpu').
            gemini_model: Google Gemini model instance (optional). Passed to
                the Gemini validation step and to TreatmentGenerator.

        Raises:
            KeyError: if 'MODEL_PATH' or 'CLASSES' is missing from ``config``.
        """
        self.config = config
        self.gemini_model = gemini_model
        self.confidence_threshold = config.get('CONFIDENCE_THRESHOLD', CONFIDENCE_THRESHOLD)

        self.validator = FishImageValidator(
            max_size_mb=config.get('MAX_FILE_SIZE_MB', 10),
            min_size_px=config.get('MIN_IMAGE_SIZE_PX', 100),
            valid_extensions=config.get('VALID_EXTENSIONS'),
        )

        self.model_loader = VGG16ModelLoader(
            model_path=config['MODEL_PATH'],
            num_classes=len(config['CLASSES']),
            device=config.get('DEVICE', 'cpu'),
        )

        self.treatment_generator = TreatmentGenerator(gemini_model)

    @staticmethod
    def _empty_result():
        """Return a fresh, not-yet-successful result dict (shape documented on predict)."""
        return {
            'success': False,
            'error': None,
            'validation': {},
            'prediction': {},
            'treatment': None,
        }

    def _passes_gemini(self, image, result):
        """Run the Gemini content check on ``image`` and record it in ``result``.

        Sets ``result['validation']['gemini']`` and, on failure,
        ``result['error']``. Exceptions raised by the validator propagate.

        Returns:
            The validator's verdict (truthy when the image passed).
        """
        print("🤖 Validating with Gemini AI...")
        is_valid, msg = self.validator.validate_with_gemini(image, self.gemini_model)
        result['validation']['gemini'] = {'valid': is_valid, 'message': msg}
        print(msg)
        if not is_valid:
            result['error'] = msg
        return is_valid

    def _analyze(self, image, result):
        """Classify ``image`` and fill the prediction/treatment parts of ``result``.

        Sets ``result['prediction']``, sets ``result['treatment']`` when the
        confidence reaches the threshold, and finally marks
        ``result['success']``. Exceptions propagate to the caller; any fields
        already written before the failure are left in place.
        """
        print("🔬 Analyzing fish health...")
        class_idx, confidence, probabilities = self.model_loader.predict(image)

        classes = self.config['CLASSES']
        predicted_class = classes[class_idx]

        # Each probabilities[i] is read via .item() (presumably a 0-1 tensor
        # element) and reported as a percentage keyed by class name.
        # NOTE(review): `confidence` is compared to the threshold unscaled;
        # confirm it uses the same units as CONFIDENCE_THRESHOLD.
        prob_dict = {
            name: float(probabilities[i].item() * 100)
            for i, name in enumerate(classes)
        }

        result['prediction'] = {
            'disease': predicted_class,
            'confidence': confidence,
            'probabilities': prob_dict,
            'below_threshold': confidence < self.confidence_threshold,
        }

        # Only produce treatment advice for predictions the model is confident in.
        if confidence >= self.confidence_threshold:
            print("💊 Generating treatment recommendations...")
            result['treatment'] = self.treatment_generator.get_recommendations(
                predicted_class, confidence
            )

        result['success'] = True

    def predict(self, image_path):
        """
        Complete prediction pipeline with all validations

        Args:
            image_path: Path to image file

        Returns:
            dict: Complete result with validation, prediction, and treatment
            {
                'success': bool,            # True only if inference completed
                'error': str or None,       # failure message, if any
                'validation': {             # only stages that actually ran
                    'file': {'valid': bool, 'message': str},
                    'gemini': {'valid': bool, 'message': str}
                },
                'prediction': {             # {} if inference never ran
                    'disease': str,
                    'confidence': float,
                    'probabilities': dict,  # class name -> percentage
                    'below_threshold': bool
                },
                'treatment': str or None    # None when below the threshold
            }

        Raises:
            Exceptions from the file or Gemini validators propagate. Errors
            during inference or treatment generation are caught and reported
            in 'error' instead.
        """
        result = self._empty_result()

        print("🔍 Validating file...")
        is_valid, msg, image = self.validator.validate_file(image_path)
        result['validation']['file'] = {'valid': is_valid, 'message': msg}
        if not is_valid:
            result['error'] = msg
            return result

        if not self._passes_gemini(image, result):
            return result

        try:
            self._analyze(image, result)
        except Exception as e:  # pipeline boundary: report instead of raising
            result['error'] = f"❌ Prediction failed: {e}"
        return result

    def predict_from_image(self, pil_image):
        """
        Predict directly from PIL Image (for web UI)

        The image is converted to RGB, and the file-validation stage is
        recorded as passed because there is no file to check. The remaining
        stages match predict().

        Args:
            pil_image: PIL Image object

        Returns:
            dict: Same as predict() method. Unlike predict(), any exception
            (including from RGB conversion or Gemini validation) is caught
            and reported in 'error'.
        """
        result = self._empty_result()
        try:
            image = pil_image.convert('RGB')
            result['validation']['file'] = {'valid': True, 'message': '✅ Image provided'}

            if not self._passes_gemini(image, result):
                return result

            self._analyze(image, result)
        except Exception as e:  # pipeline boundary: report instead of raising
            result['error'] = f"❌ Prediction failed: {e}"
        return result
|
|
|