Spaces:
Sleeping
Sleeping
| import torch | |
| import torchvision.transforms as transforms | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from PIL import Image | |
| import gradio as gr | |
| import os # Import os to check for model file | |
# === Simple CNN Model Definition ===
class SimpleCNN(nn.Module):
    """Small two-convolution CNN producing 10 class logits for 32x32 RGB images.

    Layout: conv(3->32) -> ReLU -> 2x2 max-pool -> conv(32->64) -> ReLU
    -> 2x2 max-pool -> fc(64*8*8 -> 512) -> ReLU -> fc(512 -> 10).

    The two 2x2 pools halve each spatial side twice, so the classifier head
    only fits a 64x8x8 feature map, i.e. 32x32 inputs.
    """

    def __init__(self):
        super().__init__()
        # Attribute names double as state_dict keys; renaming them would break
        # loading previously saved checkpoints.
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        """Return raw logits.

        Args:
            x: a (N, 3, 32, 32) batch, or a single (3, 32, 32) image.

        Returns:
            Logits of shape (N, 10); a single unbatched image yields (1, 10).

        Raises:
            ValueError: if the input's spatial size does not reduce to the
                64x8x8 feature map that ``fc1`` expects.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Without this check, view(-1, 4096) would silently reinterpret e.g. a
        # 64x64 input's 64x16x16 map as four separate "samples".
        if x.shape[-3:] != (64, 8, 8):
            raise ValueError(
                f"SimpleCNN expects 32x32 inputs (64x8x8 feature map before fc1), "
                f"got a feature map of {tuple(x.shape[-3:])}"
            )
        x = x.view(-1, 64 * 8 * 8)
        x = F.relu(self.fc1(x))
        return self.fc2(x)
# === Model Loading ===
model = SimpleCNN()
model_path = 'simple_cnn_dclr_tuned.pth'
# isfile rather than exists: a directory with this name should be reported as
# missing instead of crashing inside torch.load.
if os.path.isfile(model_path):
    # NOTE: torch.load unpickles the file - only load checkpoints from a trusted
    # source. map_location remaps any CUDA tensors onto the CPU.
    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
    print(f"Model loaded successfully from {model_path}")
else:
    print(f"Warning: Model file '{model_path}' not found. Please ensure 'train_dclr_model.py' has been run.")
    # The app still starts, but predictions come from randomly initialised weights.
# Put the model in evaluation mode on both paths so it never serves in training mode.
model.eval()
# === CIFAR-10 Class Labels ===
# Position i names output logit i of the model (standard CIFAR-10 index order).
# NOTE(review): must match the label order used during training - confirm.
class_labels = [
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]
# === Image Preprocessing ===
# Turns an RGB PIL image into a normalised (3, 32, 32) float tensor for SimpleCNN.
preprocess = transforms.Compose([
    # An (h, w) tuple forces exactly 32x32. A bare int would only scale the
    # shorter edge to 32 and keep the aspect ratio, so non-square uploads
    # would not reach the 64x8x8 feature map that SimpleCNN's fc1 requires.
    transforms.Resize((32, 32)),  # CIFAR-10 images are 32x32
    transforms.ToTensor(),
    # ImageNet channel statistics. NOTE(review): these should equal the
    # normalisation used when training the checkpoint - confirm.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# === Inference Function ===
def inference(input_image: Image.Image):
    """Classify one image and return per-class probabilities.

    Args:
        input_image: PIL image from the ``gr.Image(type='pil')`` input, or
            ``None`` when the user submits without an image.

    Returns:
        Dict mapping each label in ``class_labels`` to its softmax
        probability as a Python float, or ``None`` when no image was given
        (presumably shown by ``gr.Label`` as an empty output - confirm).
    """
    if input_image is None:
        return None
    if model.training:  # Ensure model is in eval mode
        model.eval()
    # ToTensor emits one channel per image band, but Normalize and conv1 need
    # exactly 3; grayscale, palette and RGBA uploads must be converted first.
    rgb_image = input_image.convert('RGB')
    # Add the batch dimension the model expects: (1, 3, 32, 32).
    batch = preprocess(rgb_image).unsqueeze(0)
    with torch.no_grad():
        logits = model(batch)
    probabilities = F.softmax(logits, dim=1)[0].tolist()
    return {label: float(p) for label, p in zip(class_labels, probabilities)}
# === Gradio Interface Setup ===
def _find_example_images(folder='examples'):
    """Return sorted paths of image files in *folder*, resolved next to this script.

    Returns an empty list when the folder does not exist, so the interface
    simply shows no examples until images are added (e.g. an ``examples/``
    directory in the Hugging Face Space).
    """
    base_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), folder)
    if not os.path.isdir(base_dir):
        return []
    image_exts = ('.png', '.jpg', '.jpeg', '.bmp', '.gif', '.webp')
    return [
        os.path.join(base_dir, name)
        for name in sorted(os.listdir(base_dir))
        if name.lower().endswith(image_exts)
    ]


# Example inputs shown under the upload box; add image files to 'examples/'
# to populate it. Empty when no such folder exists.
example_images = _find_example_images()
def _build_interface():
    """Wire `inference` into a Gradio UI: one image in, top-3 labels out."""
    image_in = gr.Image(type='pil', label='Input Image')
    label_out = gr.Label(num_top_classes=3, label='Predictions')
    return gr.Interface(
        fn=inference,
        inputs=image_in,
        outputs=label_out,
        title='CIFAR-10 Image Classification with DCLR Optimizer',
        description="Upload an image and see the model's predictions using a SimpleCNN trained with the DCLR optimizer.",
        examples=example_images,
        # NOTE(review): newer Gradio releases rename this to `flagging_mode`;
        # confirm against the pinned Gradio version.
        allow_flagging='never',
    )


interface = _build_interface()
# === Launch Gradio App ===
# Start the web server only when executed as a script, not when imported.
if __name__ == "__main__":
    interface.launch()