"""Train a small CNN on CIFAR-10 with the DCLR optimizer and save the results.

Running this file loads CIFAR-10 from ``./data`` (``download=True`` is passed
to torchvision), trains ``SimpleCNN`` for a fixed number of epochs, evaluates
it on the test split, and writes the model weights, two plots and the final
test accuracy into the current working directory.
"""
import os

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Import the DCLR optimizer from the local file
from dclr_optimizer import DCLR


# === Simple CNN Model Definition ===
class SimpleCNN(nn.Module):
    """Two conv + max-pool stages followed by two fully connected layers.

    ``fc1`` expects ``64 * 8 * 8`` features per sample, which is what two 2x2
    poolings of a 32x32 input yield with the padded 3x3 convolutions used
    here. The network returns 10 raw logits per sample (no softmax).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        """Return class logits for a batch of 3-channel images ``x``."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten to (N, 4096) for the linear layers. NOTE(review): because
        # the batch dimension is inferred (-1), an input whose spatial size is
        # not 32x32 may be reshaped into a different batch size instead of
        # raising; callers must feed 32x32 images.
        x = x.view(-1, 64 * 8 * 8)
        x = F.relu(self.fc1(x))
        return self.fc2(x)


# === CIFAR-10 Data Loading ===
# ToTensor yields values in [0, 1]; normalising with mean 0.5 / std 0.5 per
# channel rescales them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

train_set = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform
)
train_loader = DataLoader(train_set, batch_size=128, shuffle=True)

test_set = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform
)
test_loader = DataLoader(test_set, batch_size=128, shuffle=False)

# === Training Configuration ===
model = SimpleCNN()
# NOTE(review): the meaning of `lambda_` is defined by the local DCLR
# implementation, which is not visible from this file.
optimizer = DCLR(model.parameters(), lr=0.1, lambda_=0.1, verbose=False)
criterion = nn.CrossEntropyLoss()
epochs = 20

print(f"Starting training with DCLR for {epochs} epochs...")

# Per-epoch training loss and training accuracy (%), used for the plot below.
losses, accs = [], []

# === Training Loop ===
for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        # DCLR's step() is handed the batch logits in addition to the
        # gradients already stored on the parameters.
        optimizer.step(output_activations=outputs)

        batch_size = labels.size(0)
        # CrossEntropyLoss returns the mean over the batch; weight it by the
        # batch size so the (smaller) final batch is not over-represented in
        # the epoch average and the loss is per-sample, like the accuracy.
        running_loss += loss.item() * batch_size
        _, predicted = outputs.max(1)
        total += batch_size
        correct += predicted.eq(labels).sum().item()

    epoch_loss = running_loss / total
    epoch_acc = 100.0 * correct / total
    losses.append(epoch_loss)
    accs.append(epoch_acc)
    print(f"Epoch {epoch+1}/{epochs} - Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%")

print("Training complete.")

# === Evaluate on Test Set ===
model.eval()
correct = 0
total = 0
# No gradients are needed for evaluation.
with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = model(inputs)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

test_acc = 100.0 * correct / total
print(f"Final Test Accuracy: {test_acc:.2f}%")

# === Save the Trained Model ===
torch.save(model.state_dict(), 'simple_cnn_dclr_tuned.pth')
print("Model saved to simple_cnn_dclr_tuned.pth")

# === Save Training Performance Plot ===
# NOTE(review): loss and accuracy (%) share one y-axis, so the loss curve is
# compressed near zero; consider a secondary axis if the loss must be readable.
fig = plt.figure()
plt.plot(range(1, epochs+1), losses, label='Loss')
plt.plot(range(1, epochs+1), accs, label='Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Value')
plt.legend()
plt.title('Training Performance on CIFAR-10 (DCLR)')
plt.savefig('training_performance.png')
# pyplot keeps figures alive until they are closed explicitly.
plt.close(fig)
print("Training performance plot saved to training_performance.png")

# === Save Final Test Accuracy Plot ===
fig = plt.figure()
plt.bar(['CIFAR-10'], [test_acc])
plt.ylabel('Accuracy (%)')
plt.title('Final Test Accuracy (DCLR)')
plt.savefig('final_test_accuracy.png')
plt.close(fig)
print("Final test accuracy plot saved to final_test_accuracy.png")

# === Save Final Test Accuracy Number ===
with open("final_test_accuracy.txt", "w", encoding="utf-8") as f:
    f.write(f"{test_acc:.2f}")
print("Final test accuracy saved to final_test_accuracy.txt")