Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python | |
| # coding: utf-8 | |
| # In[1]: | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import cv2 | |
| import numpy as np | |
| from torchvision import transforms | |
| import matplotlib.pyplot as plt | |
| from PIL import Image | |
| import streamlit as st | |
| # In[2]: | |
def preprocess_image(image_path, image_size=(224, 224)):
    """
    Load an image and prepare it for model inference.

    Args:
        image_path: Path (or file-like object) accepted by ``PIL.Image.open``.
        image_size: ``(height, width)`` passed to ``transforms.Resize``.
            Defaults to ``(224, 224)``.

    Returns:
        tuple: ``(tensor, img)`` where ``tensor`` is the normalized RGB image
        with a leading batch dimension of 1, shape ``(1, 3, H, W)``, and
        ``img`` is the RGB ``PIL.Image`` at its original resolution.
    """
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        # ImageNet channel statistics -- presumably what the model was trained
        # with; TODO confirm against the training pipeline.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # Context manager releases the file handle deterministically (multi-frame
    # formats keep it open after load otherwise). convert() returns a new,
    # fully loaded image, so `img` stays valid after the file is closed.
    with Image.open(image_path) as src:
        img = src.convert('RGB')
    tensor = transform(img)
    return tensor.unsqueeze(0), img
| # In[3]: | |
def get_last_conv_layer(model):
    """
    Return the qualified name of the last ``nn.Conv2d`` inside ``model``.

    "Last" is by ``model.named_modules()`` order (module registration order).
    For ResNet-style models this is assumed to be the final convolution run
    in the forward pass.

    Raises:
        ValueError: If ``model`` contains no ``nn.Conv2d`` module.
    """
    conv_names = [
        qualified_name
        for qualified_name, submodule in model.named_modules()
        if isinstance(submodule, nn.Conv2d)
    ]
    if not conv_names:
        raise ValueError("No Conv2d layers found in the model.")
    return conv_names[-1]
| # In[4]: | |
def _cam_from_activations(feature_map, gradient):
    """
    Combine one image's layer activations and gradients into a Grad-CAM map.

    Args:
        feature_map (np.ndarray): Target-layer activations, shape (C, h, w).
        gradient (np.ndarray): Gradient of the class score w.r.t. those
            activations, same shape.

    Returns:
        np.ndarray: (h, w) map scaled to [0, 1]. It is all zeros when no
        location has a positive contribution, instead of the NaNs a 0/0
        division would produce.
    """
    # Channel weights = global average pooling of the gradients.
    weights = np.mean(gradient, axis=(1, 2), keepdims=True)
    cam = np.sum(feature_map * weights, axis=0)
    cam = np.maximum(cam, 0)  # ReLU: keep only positive evidence
    cam = cam - np.min(cam)
    peak = np.max(cam)
    if peak > 0:
        cam = cam / peak
    return cam


def apply_gradcam(model, image_tensor, target_class=None):
    """
    Compute a Grad-CAM heatmap for the first image in ``image_tensor``.

    Hooks the last ``nn.Conv2d`` reported by ``get_last_conv_layer`` to capture
    its output and the gradient flowing back into that output. It then weights
    each activation channel by its spatially averaged gradient.

    Args:
        model (nn.Module): Classifier whose output has shape (N, num_classes).
            Side effects: it is switched to eval mode, and its parameter
            ``.grad`` fields are zeroed and then filled by the backward pass.
        image_tensor (torch.Tensor): Batched input (N, C, H, W). It is moved
            to the device of the model's parameters. Only sample 0 is explained.
        target_class (int, optional): Class index to explain. Defaults to the
            highest-scoring class of sample 0.

    Returns:
        np.ndarray: (H, W) heatmap with values in [0, 1], resized to the
        input tensor's spatial size.
    """
    device = next(model.parameters()).device
    image_tensor = image_tensor.to(device)

    features = []
    gradients = []

    def forward_hook(module, inputs, output):
        features.append(output.detach())

    def backward_hook(module, grad_input, grad_output):
        # grad_output[0]: gradient of the backpropagated score w.r.t. this
        # layer's output.
        gradients.append(grad_output[0].detach())

    last_conv_layer_name = get_last_conv_layer(model)
    last_conv_layer = dict(model.named_modules())[last_conv_layer_name]

    handle_forward = last_conv_layer.register_forward_hook(forward_hook)
    handle_backward = last_conv_layer.register_full_backward_hook(backward_hook)
    # try/finally: a failing forward/backward must not leave hooks attached to
    # the caller's model, where they would fire on every later call.
    try:
        model.eval()
        output = model(image_tensor)
        if target_class is None:
            target_class = output[0].argmax().item()
        model.zero_grad()
        # Backpropagate only the selected class score of sample 0.
        one_hot = torch.zeros_like(output)
        one_hot[0, target_class] = 1
        output.backward(gradient=one_hot)
    finally:
        handle_forward.remove()
        handle_backward.remove()

    # Index sample 0 instead of squeeze(): squeeze() would also drop channel
    # or spatial axes of size 1 and break the (C, h, w) layout expected below.
    feature_map = features[-1][0].cpu().numpy()
    gradient = gradients[-1][0].cpu().numpy()
    cam = _cam_from_activations(feature_map, gradient)

    # Match the network input's spatial size (224x224 for preprocess_image).
    height, width = image_tensor.shape[-2:]
    return cv2.resize(cam, (int(width), int(height)))
| # In[5]: | |
def overlay_heatmap(original_image, heatmap, alpha=0.5):
    """
    Overlay a Grad-CAM heatmap on an image.

    Args:
        original_image (np.ndarray): Original image (H, W, 3), uint8, RGB.
        heatmap (np.ndarray): Grad-CAM heatmap (H', W'), float expected in
            [0, 1]. A 3-D heatmap is first averaged over its last axis. NaNs
            become 0, and values outside [0, 1] are clipped rather than
            wrapping around in the uint8 cast.
        alpha (float): Weight for the heatmap; the image gets ``1 - alpha``.

    Returns:
        np.ndarray: Blended RGB image (H, W, 3), uint8.
    """
    # Ensure heatmap is 2D
    if heatmap.ndim == 3:
        heatmap = np.mean(heatmap, axis=2)
    # Resize heatmap to match original image size (cv2 takes (width, height))
    heatmap_resized = cv2.resize(heatmap, (original_image.shape[1], original_image.shape[0]))
    # np.uint8(255 * x) wraps for x > 1 or x < 0 and is undefined for NaN,
    # so sanitize and clip before scaling to [0, 255].
    heatmap_resized = np.clip(np.nan_to_num(heatmap_resized, nan=0.0), 0.0, 1.0)
    heatmap_resized = np.uint8(255 * heatmap_resized)
    heatmap_colored = cv2.applyColorMap(heatmap_resized, cv2.COLORMAP_JET)
    # applyColorMap returns BGR; the input image is RGB.
    heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)
    superimposed_img = heatmap_colored * alpha + original_image * (1 - alpha)
    # Clip so an alpha outside [0, 1] cannot wrap around in the cast.
    return np.clip(superimposed_img, 0, 255).astype(np.uint8)
def visualize_gradcam(model, image_path, target_class=None, alpha=0.4):
    """
    Show an image next to its Grad-CAM overlay in Streamlit.

    Args:
        model: Model passed to ``apply_gradcam``.
        image_path: Image location passed to ``preprocess_image``.
        target_class (int, optional): Class to explain; ``None`` explains the
            predicted class. Forwarded to ``apply_gradcam``.
        alpha (float): Heatmap weight in the overlay; the image gets
            ``1 - alpha``. Default 0.4 (heatmap 0.4 / image 0.6).

    Side effects:
        Renders a matplotlib figure via ``st.pyplot`` and then closes it.
    """
    image_tensor, original_image = preprocess_image(image_path)
    original_image_np = np.array(original_image)  # PIL RGB -> (H, W, 3) uint8

    # Display-only copy. cv2.resize stretches to exactly 400x400; it does not
    # preserve the aspect ratio.
    display_size = (400, 400)
    original_image_resized = cv2.resize(original_image_np, display_size)

    cam = apply_gradcam(model, image_tensor, target_class=target_class)

    # Upsample to full resolution, then re-stretch so the strongest response
    # maps to 1. Skip the division for an all-zero map (0/0 would give NaNs).
    heatmap = cv2.resize(cam, (original_image_np.shape[1], original_image_np.shape[0]))
    peak = np.max(heatmap)
    if peak > 0:
        heatmap = heatmap / peak
    superimposed_img = overlay_heatmap(original_image_np, heatmap, alpha=alpha)

    fig, axes = plt.subplots(1, 2, figsize=(8, 4))
    # Close the figure even if drawing fails, so figures do not accumulate.
    try:
        axes[0].imshow(original_image_resized)
        axes[0].set_title("Original Image")
        axes[0].axis("off")
        axes[1].imshow(superimposed_img)
        axes[1].set_title("Grad-CAM Heatmap")
        axes[1].axis("off")
        plt.tight_layout()
        st.pyplot(fig)
    finally:
        plt.close(fig)
| # In[6]: | |
if __name__ == "__main__":
    # visualize_gradcam draws through st.pyplot, so the figure is only visible
    # when this file is launched with `streamlit run`.
    from models.resnet_model import MalariaResNet50

    # Load the trained model
    model = MalariaResNet50(num_classes=2)
    # map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only
    # host; apply_gradcam moves inputs to wherever the parameters live.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load("models/malaria_model.pth", map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()

    # Example test image
    image_path = "malaria_ds/split_dataset/test/Parasitized/C33P1thinF_IMG_20150619_114756a_cell_181.png"
    visualize_gradcam(model, image_path)
| # In[ ]: | |