RFTSystems commited on
Commit
a8da7b7
·
verified ·
1 Parent(s): bcfe5bc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +94 -0
app.py CHANGED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision.transforms as transforms
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from PIL import Image
6
+ import gradio as gr
7
+ import os # Import os to check for model file
8
+
9
+ # === Simple CNN Model Definition ===
10
+ class SimpleCNN(nn.Module):
11
+ def __init__(self):
12
+ super(SimpleCNN, self).__init__()
13
+ self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
14
+ self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
15
+ self.pool = nn.MaxPool2d(2, 2)
16
+ self.fc1 = nn.Linear(64 * 8 * 8, 512)
17
+ self.fc2 = nn.Linear(512, 10)
18
+
19
+ def forward(self, x):
20
+ x = self.pool(F.relu(self.conv1(x)))
21
+ x = self.pool(F.relu(self.conv2(x)))
22
+ x = x.view(-1, 64 * 8 * 8)
23
+ x = F.relu(self.fc1(x))
24
+ return self.fc2(x)
25
+
26
+ # === Model Loading ===
27
+ model = SimpleCNN()
28
+ model_path = 'simple_cnn_dclr_tuned.pth'
29
+
30
+ # Check if the model file exists before loading
31
+ if os.path.exists(model_path):
32
+ model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
33
+ model.eval() # Set model to evaluation mode
34
+ print(f"Model loaded successfully from {model_path}")
35
+ else:
36
+ print(f"Warning: Model file '{model_path}' not found. Please ensure 'train_dclr_model.py' has been run.")
37
+ # Optionally, you might want to exit or raise an error if the model is crucial
38
+
39
+
40
+ # === CIFAR-10 Class Labels ===
41
+ class_labels = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
42
+
43
+ # === Image Preprocessing ===
44
+ preprocess = transforms.Compose([
45
+ transforms.Resize(32), # CIFAR-10 images are 32x32
46
+ transforms.ToTensor(),
47
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # ImageNet stats are common
48
+ ])
49
+
50
+ # === Inference Function ===
51
+ def inference(input_image: Image.Image):
52
+ if model.training: # Ensure model is in eval mode
53
+ model.eval()
54
+
55
+ # Preprocess the image
56
+ processed_image = preprocess(input_image)
57
+ # Add a batch dimension
58
+ processed_image = processed_image.unsqueeze(0)
59
+
60
+ # Perform inference
61
+ with torch.no_grad():
62
+ outputs = model(processed_image)
63
+ probabilities = F.softmax(outputs, dim=1)
64
+
65
+ # Convert probabilities to a dictionary of class labels and scores
66
+ confidences = {class_labels[i]: float(probabilities[0, i]) for i in range(len(class_labels))}
67
+ return confidences
68
+
69
+ # === Gradio Interface Setup ===
70
+ # Example images (replace with actual paths if available, or keep as dummy for now)
71
+ # For a Hugging Face Space, you might place example images in an 'examples/' directory.
72
+ example_images = [
73
+ # os.path.join(os.path.dirname(__file__), "examples/example_car.png"),
74
+ # os.path.join(os.path.dirname(__file__), "examples/example_dog.png"),
75
+ # os.path.join(os.path.dirname(__file__), "examples/example_plane.png")
76
+ ]
77
+
78
+ # A placeholder for example images since we don't have them generated yet.
79
+ # Users can upload their own or I will add some placeholder images if needed in the next step.
80
+ # For now, an empty list of examples is fine.
81
+
82
+ interface = gr.Interface(
83
+ fn=inference,
84
+ inputs=gr.Image(type='pil', label='Input Image'),
85
+ outputs=gr.Label(num_top_classes=3, label='Predictions'),
86
+ title='CIFAR-10 Image Classification with DCLR Optimizer',
87
+ description='Upload an image and see the model\'s predictions using a SimpleCNN trained with the DCLR optimizer.',
88
+ examples=example_images,
89
+ allow_flagging='never'
90
+ )
91
+
92
+ # === Launch Gradio App ===
93
+ if __name__ == '__main__':
94
+ interface.launch()