# Author: mathaisjustin
# Commit fbbdeab: Deploy Fish Disease Detection AI
"""
Predictor - Orchestrates the complete prediction pipeline
Combines validation, model inference, and treatment generation
"""
from .validator import FishImageValidator
from .model_loader import VGG16ModelLoader
from .treatment import TreatmentGenerator
from .config import CLASSES, CONFIDENCE_THRESHOLD
class FishDiseasePredictor:
    """Main prediction pipeline that coordinates all components.

    Wires together a ``FishImageValidator`` (file + Gemini checks), a
    ``VGG16ModelLoader`` (disease classification) and a
    ``TreatmentGenerator`` (recommendations), and exposes two entry points:
    ``predict`` (from a file path) and ``predict_from_image`` (from an
    in-memory PIL image).
    """

    def __init__(self, config, gemini_model=None):
        """
        Initialize predictor with all components.

        Args:
            config: Configuration mapping. Keys read here:
                'MODEL_PATH' (required), 'CLASSES' (defaults to the package
                ``CLASSES``), 'CONFIDENCE_THRESHOLD' (defaults to the package
                ``CONFIDENCE_THRESHOLD``), 'MAX_FILE_SIZE_MB' (default 10),
                'MIN_IMAGE_SIZE_PX' (default 100), 'VALID_EXTENSIONS'
                (default None), 'DEVICE' (default 'cpu').
            gemini_model: Google Gemini model instance (optional); passed to
                both the validator and the treatment generator.

        Raises:
            KeyError: if 'MODEL_PATH' is missing from ``config``.
        """
        self.config = config
        self.gemini_model = gemini_model
        self.confidence_threshold = config.get('CONFIDENCE_THRESHOLD', CONFIDENCE_THRESHOLD)
        # Resolve the class list once, falling back to the package default the
        # same way the threshold does. The model's output size is derived from
        # it, so it must stay fixed for the lifetime of this predictor.
        self.classes = config.get('CLASSES', CLASSES)

        # Initialize all components
        self.validator = FishImageValidator(
            max_size_mb=config.get('MAX_FILE_SIZE_MB', 10),
            min_size_px=config.get('MIN_IMAGE_SIZE_PX', 100),
            valid_extensions=config.get('VALID_EXTENSIONS'),
        )
        self.model_loader = VGG16ModelLoader(
            model_path=config['MODEL_PATH'],
            num_classes=len(self.classes),
            device=config.get('DEVICE', 'cpu'),
        )
        self.treatment_generator = TreatmentGenerator(gemini_model)

    @staticmethod
    def _empty_result():
        """Return a fresh result skeleton shared by both entry points."""
        return {
            'success': False,
            'error': None,
            'validation': {},
            'prediction': {},
            'treatment': None,
        }

    def _validate_with_gemini(self, image, result):
        """Run the Gemini check, record it in ``result``; return True if it passed."""
        print("πŸ€– Validating with Gemini AI...")
        is_valid, msg = self.validator.validate_with_gemini(image, self.gemini_model)
        result['validation']['gemini'] = {'valid': is_valid, 'message': msg}
        print(msg)
        if not is_valid:
            result['error'] = msg
        return is_valid

    def _classify_and_treat(self, image, result):
        """Run model inference and (if confident) treatment, filling ``result`` in place.

        Exceptions from the model or treatment generator propagate; callers
        convert them into ``result['error']``. Whatever was written to
        ``result`` before the failure is kept.
        """
        print("πŸ”¬ Analyzing fish health...")
        class_idx, confidence, probabilities = self.model_loader.predict(image)
        predicted_class = self.classes[class_idx]

        # Per-class probabilities scaled to percentages. Assumes
        # ``probabilities`` is indexable per class with elements exposing
        # ``.item()`` (e.g. a 1-D torch tensor) -- confirm with VGG16ModelLoader.
        prob_dict = {
            name: float(probabilities[i].item() * 100)
            for i, name in enumerate(self.classes)
        }

        # NOTE(review): ``confidence`` is compared directly with the
        # threshold; presumably both use the same scale -- confirm.
        result['prediction'] = {
            'disease': predicted_class,
            'confidence': confidence,
            'probabilities': prob_dict,
            'below_threshold': confidence < self.confidence_threshold,
        }

        # Only spend a treatment request on confident predictions.
        if confidence >= self.confidence_threshold:
            print("πŸ’Š Generating treatment recommendations...")
            result['treatment'] = self.treatment_generator.get_recommendations(
                predicted_class, confidence
            )

        result['success'] = True
        return result

    def predict(self, image_path):
        """
        Complete prediction pipeline with all validations.

        Args:
            image_path: Path to image file.

        Returns:
            dict: Complete result with validation, prediction, and treatment
            {
                'success': bool,
                'error': str or None,
                'validation': {
                    'file': {'valid': bool, 'message': str},
                    'gemini': {'valid': bool, 'message': str}
                },
                'prediction': {
                    'disease': str,
                    'confidence': float,
                    'probabilities': dict,   # class name -> percentage
                    'below_threshold': bool
                },
                'treatment': str or None     # None when below threshold
            }
            Validation entries are only present for steps that actually ran.

        Note:
            Exceptions raised by the validator (steps 1-2) are not caught here
            and propagate to the caller; only inference/treatment failures are
            reported through ``result['error']``.
        """
        result = self._empty_result()

        # STEP 1: File validation (format, size, corruption)
        print("πŸ” Validating file...")
        is_valid, msg, image = self.validator.validate_file(image_path)
        result['validation']['file'] = {'valid': is_valid, 'message': msg}
        if not is_valid:
            result['error'] = msg
            return result

        # STEP 2: Gemini AI validation (fish detection, edge cases)
        if not self._validate_with_gemini(image, result):
            return result

        # STEP 3 + 4: disease prediction, then treatment when confident
        try:
            self._classify_and_treat(image, result)
        except Exception as e:
            result['error'] = f"❌ Prediction failed: {str(e)}"
        return result

    def predict_from_image(self, pil_image):
        """
        Predict directly from a PIL Image (for web UI).

        Args:
            pil_image: PIL Image object, or None when nothing was supplied.

        Returns:
            dict: Same shape as ``predict()``. Unlike ``predict()``, any
            exception (including from Gemini validation) is caught and
            reported through ``result['error']``.
        """
        result = self._empty_result()

        # Web UIs typically hand over None when no image was uploaded; report
        # that plainly instead of failing on ``None.convert``.
        if pil_image is None:
            msg = "❌ No image provided"
            result['validation']['file'] = {'valid': False, 'message': msg}
            result['error'] = msg
            return result

        try:
            # Convert to RGB if needed
            image = pil_image.convert('RGB')
            # File validation not needed (already have image)
            result['validation']['file'] = {'valid': True, 'message': 'βœ… Image provided'}

            if not self._validate_with_gemini(image, result):
                return result

            self._classify_and_treat(image, result)
        except Exception as e:
            result['error'] = f"❌ Prediction failed: {str(e)}"
        return result