strelizi committed
Commit 1082a80 · verified · 1 Parent(s): 0a1c203

Upload 3 files

Files changed (3)
  1. README.md +49 -14
  2. app.py +132 -0
  3. requirements.txt +7 -0
README.md CHANGED
@@ -1,14 +1,49 @@
- ---
- title: XAI
- emoji: 🐠
- colorFrom: pink
- colorTo: indigo
- sdk: gradio
- sdk_version: 5.49.1
- app_file: app.py
- pinned: false
- license: mit
- short_description: image analyzer
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # XAI Image Classifier
+
+ An explainable image classification web app using ResNet18 trained on CIFAR-10, with Grad-CAM visualizations for interpretability.
+
+ ## Features
+
+ - **Image Classification**: Classifies images into 10 categories (airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck).
+ - **Explainability**: Uses Grad-CAM to show which parts of the image influenced the prediction (a minimal sketch follows this list).
+ - **Interactive UI**: Built with Gradio for easy web-based interaction.
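+
+ A minimal, self-contained sketch of the Grad-CAM step (random weights and a random input stand in for the real checkpoint and an uploaded image; `app.py` wires these same pieces into the Gradio UI):
+
+ ```python
+ import torch
+ import torch.nn as nn
+ from torchvision import models
+ from captum.attr import LayerGradCam, LayerAttribution
+
+ model = models.resnet18(weights=None)                    # same architecture as app.py
+ model.fc = nn.Linear(512, 10)                            # 10 CIFAR-10 classes
+ model.eval()
+
+ img_tensor = torch.randn(1, 3, 224, 224)                 # stand-in for a preprocessed upload
+
+ gradcam = LayerGradCam(model, model.layer4[1].conv2)     # attribute w.r.t. the last conv layer
+ pred_class = model(img_tensor).argmax(1).item()          # class to explain
+ cam = gradcam.attribute(img_tensor, target=pred_class)   # coarse (1, 1, 7, 7) relevance map
+ cam = LayerAttribution.interpolate(cam, (224, 224))      # upsample to the input resolution
+ heatmap = cam.squeeze().detach().cpu().numpy()           # 2-D heatmap, ready to normalize and plot
+ ```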
+
+ ## Installation
+
+ 1. Clone the repository:
+ ```bash
+ git clone https://github.com/your-username/xai-image-classifier.git
+ cd xai-image-classifier
+ ```
+
+ 2. Install dependencies:
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ 3. Download the model file `xai_resnet18.pth` and place it in the `model/` directory (a sketch of the expected checkpoint format follows this list).
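+
+ `app.py` loads this file with `torch.load(...)` and reads the weights from the `model_state_dict` key, so the checkpoint must be a dict in that format. A quick sanity check for a downloaded file (the path is the one `app.py` assumes):
+
+ ```python
+ import torch
+
+ # The app expects: {'model_state_dict': <ResNet18 state dict with a 10-class fc layer>, ...}
+ checkpoint = torch.load("model/xai_resnet18.pth", map_location="cpu")
+ print(list(checkpoint.keys()))                             # should include 'model_state_dict'
+ print(checkpoint["model_state_dict"]["fc.weight"].shape)   # expect torch.Size([10, 512])
+ ```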
+
+ ## Usage
+
+ Run the app:
+ ```bash
+ python app.py
+ ```
+
+ Open the provided URL in your browser, upload an image, and click "Analyze Image" to get predictions and explanations.
+
+ ## Requirements
+
+ - Python 3.7+
+ - CUDA-compatible GPU (optional, for faster inference)
+ - Dependencies listed in `requirements.txt`
+
+ ## Model Details
+
+ - **Architecture**: ResNet18 with a modified fully-connected layer (10 output classes).
+ - **Training Data**: CIFAR-10 dataset (60,000 images).
+ - **Pre-trained Weights**: None (trained from scratch; see the sketch after this list).
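+
+ The training recipe itself is not part of this repo. A plausible from-scratch setup consistent with the architecture and preprocessing in `app.py` could look like the sketch below; the batch size, optimizer, learning rate, and epoch count are assumptions, not the values actually used to produce `xai_resnet18.pth`:
+
+ ```python
+ import os
+ import torch
+ import torch.nn as nn
+ from torch.utils.data import DataLoader
+ from torchvision import datasets, models, transforms
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Same preprocessing app.py applies at inference time
+ transform = transforms.Compose([
+     transforms.Resize((224, 224)),
+     transforms.ToTensor(),
+     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+ ])
+
+ train_set = datasets.CIFAR10(root="data", train=True, download=True, transform=transform)
+ train_loader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=2)
+
+ model = models.resnet18(weights=None)      # no ImageNet weights: trained from scratch
+ model.fc = nn.Linear(512, 10)              # 10 CIFAR-10 classes, as in app.py
+ model = model.to(device)
+
+ criterion = nn.CrossEntropyLoss()
+ optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)   # assumed optimizer / learning rate
+
+ for epoch in range(10):                    # assumed number of epochs
+     model.train()
+     for images, labels in train_loader:
+         images, labels = images.to(device), labels.to(device)
+         optimizer.zero_grad()
+         loss = criterion(model(images), labels)
+         loss.backward()
+         optimizer.step()
+
+ # Save in the format load_model() in app.py expects
+ os.makedirs("model", exist_ok=True)
+ torch.save({"model_state_dict": model.state_dict()}, "model/xai_resnet18.pth")
+ ```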
+
+ ## License
+
+ MIT License
app.py ADDED
@@ -0,0 +1,132 @@
+ import torch
+ import torch.nn as nn
+ from torchvision import models, transforms
+ from PIL import Image
+ import gradio as gr
+ from captum.attr import LayerGradCam, LayerAttribution
+ import numpy as np
+ import matplotlib.pyplot as plt
+ from io import BytesIO
+
+ # Configuration
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ CLASS_NAMES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
+                'dog', 'frog', 'horse', 'ship', 'truck']
+
+ # Load model
+ def load_model():
+     model = models.resnet18(weights=None)  # architecture only; weights come from the checkpoint
+     model.fc = nn.Linear(512, 10)
+
+     checkpoint = torch.load('model/xai_resnet18.pth', map_location=DEVICE)
+     model.load_state_dict(checkpoint['model_state_dict'])
+     model = model.to(DEVICE)
+     model.eval()
+
+     return model
+
+ model = load_model()
+ target_layer = model.layer4[1].conv2
+ gradcam = LayerGradCam(model, target_layer)
+
+ # Image preprocessing
+ transform = transforms.Compose([
+     transforms.Resize((224, 224)),
+     transforms.ToTensor(),
+     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+ ])
+
+ # Prediction function
+ def predict_and_explain(image):
+     if image is None:
+         return "Please upload an image", None
+
+     # Preprocess (convert to RGB so grayscale/RGBA uploads don't break the 3-channel model)
+     img_tensor = transform(image.convert("RGB")).unsqueeze(0).to(DEVICE)
+
+     # Predict
+     with torch.no_grad():
+         output = model(img_tensor)
+         probabilities = torch.softmax(output, dim=1)
+         pred_class = probabilities.argmax(1).item()
+         confidence = probabilities[0][pred_class].item()
+
+     # Generate Grad-CAM and upsample the coarse map to the uploaded image's size
+     # so the heatmap and overlay align with the original image
+     attributions = gradcam.attribute(img_tensor, target=pred_class)
+     attributions = LayerAttribution.interpolate(
+         attributions, (image.size[1], image.size[0]), interpolate_mode='bilinear')
+     attr_np = attributions.squeeze().cpu().detach().numpy()
+     attr_np = (attr_np - attr_np.min()) / (attr_np.max() - attr_np.min() + 1e-8)
+
+     # Create visualization
+     fig, axes = plt.subplots(1, 3, figsize=(15, 5))
+
+     axes[0].imshow(image)
+     axes[0].set_title("Original Image", fontsize=14, fontweight='bold')
+     axes[0].axis('off')
+
+     im = axes[1].imshow(attr_np, cmap='jet')
+     axes[1].set_title("Grad-CAM Heatmap", fontsize=14, fontweight='bold')
+     axes[1].axis('off')
+     plt.colorbar(im, ax=axes[1], fraction=0.046)
+
+     axes[2].imshow(image)
+     axes[2].imshow(attr_np, cmap='jet', alpha=0.5)
+     axes[2].set_title(f"Overlay\nPrediction: {CLASS_NAMES[pred_class]}",
+                       fontsize=14, fontweight='bold')
+     axes[2].axis('off')
+
+     plt.tight_layout()
+
+     buf = BytesIO()
+     plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
+     buf.seek(0)
+     result_image = Image.open(buf)
+     plt.close(fig)
+
+     # Prediction text
+     prediction_text = f"**Prediction:** {CLASS_NAMES[pred_class]}\n\n"
+     prediction_text += f"**Confidence:** {confidence*100:.2f}%\n\n"
+     prediction_text += "**Top 3 Predictions:**\n"
+
+     top3_probs, top3_indices = torch.topk(probabilities[0], 3)
+     for prob, idx in zip(top3_probs, top3_indices):
+         prediction_text += f"- {CLASS_NAMES[idx]}: {prob.item()*100:.2f}%\n"
+
+     return prediction_text, result_image
+
+ # Gradio Interface
+ with gr.Blocks(title="🔍 Explainable Image Classifier", theme=gr.themes.Soft()) as demo:
+     gr.Markdown("""
+     # 🔍 Explainable Image Classifier with Grad-CAM
+
+     Upload an image and see:
+     - **What** the AI predicts (classification)
+     - **Why** it made that decision (Grad-CAM visualization)
+
+     **Supported categories:** airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck
+     """)
+
+     with gr.Row():
+         with gr.Column():
+             input_image = gr.Image(type="pil", label="Upload Image")
+             predict_btn = gr.Button("🔍 Analyze Image", variant="primary", size="lg")
+
+         with gr.Column():
+             output_text = gr.Markdown(label="Prediction Results")
+             output_image = gr.Image(label="Grad-CAM Visualization", type="pil")
+
+     predict_btn.click(
+         fn=predict_and_explain,
+         inputs=input_image,
+         outputs=[output_text, output_image]
+     )
+
+     gr.Markdown("""
+     ---
+     ### 🧠 About This Model
+     - **Architecture:** ResNet18 (trained from scratch)
+     - **Training Data:** CIFAR-10 (60,000 images)
+     - **Explainability:** Grad-CAM visualization
+     """)
+
+ if __name__ == "__main__":
+     demo.launch()
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ torch
+ torchvision
+ gradio
+ captum
+ Pillow
+ matplotlib
+ numpy